Merge pull request #6096 from mnoskova:mn/SVMSGD_to_opencv3_0
This commit is contained in:
@@ -1499,6 +1499,165 @@ public:
|
||||
CV_WRAP static Ptr<LogisticRegression> create();
|
||||
};
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier *
|
||||
\****************************************************************************************/
|
||||
|
||||
/*!
|
||||
@brief Stochastic Gradient Descent SVM classifier
|
||||
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach,
|
||||
as presented in @cite bottou2010large.
|
||||
|
||||
The classifier has following parameters:
|
||||
- model type,
|
||||
- margin type,
|
||||
- margin regularization (\f$\lambda\f$),
|
||||
- initial step size (\f$\gamma_0\f$),
|
||||
- step decreasing power (\f$c\f$),
|
||||
- and termination criteria.
|
||||
|
||||
The model type may have one of the following values: \ref SGD and \ref ASGD.
|
||||
|
||||
- \ref SGD is the classic version of SVMSGD classifier: every next step is calculated by the formula
|
||||
\f[w_{t+1} = w_t - \gamma(t) \frac{dQ_i}{dw} |_{w = w_t}\f]
|
||||
where
|
||||
- \f$w_t\f$ is the weights vector for decision function at step \f$t\f$,
|
||||
- \f$\gamma(t)\f$ is the step size of model parameters at the iteration \f$t\f$, it is decreased on each step by the formula
|
||||
\f$\gamma(t) = \gamma_0 (1 + \lambda \gamma_0 t) ^ {-c}\f$
|
||||
- \f$Q_i\f$ is the target functional from SVM task for sample with number \f$i\f$, this sample is chosen stochastically on each step of the algorithm.
|
||||
|
||||
- \ref ASGD is Average Stochastic Gradient Descent SVM Classifier. ASGD classifier averages weights vector on each step of algorithm by the formula
|
||||
\f$\widehat{w}_{t+1} = \frac{t}{1+t}\widehat{w}_{t} + \frac{1}{1+t}w_{t+1}\f$
|
||||
|
||||
The recommended model type is ASGD (following @cite bottou2010large).
|
||||
|
||||
The margin type may have one of the following values: \ref SOFT_MARGIN or \ref HARD_MARGIN.
|
||||
|
||||
- You should use \ref HARD_MARGIN type, if you have linearly separable sets.
|
||||
- You should use \ref SOFT_MARGIN type, if you have non-linearly separable sets or sets with outliers.
|
||||
- In the general case (if you know nothing about linear separability of your sets), use SOFT_MARGIN.
|
||||
|
||||
The other parameters may be described as follows:
|
||||
- Margin regularization parameter is responsible for weights decreasing at each step and for the strength of restrictions on outliers
|
||||
(the less the parameter, the less probability that an outlier will be ignored).
|
||||
Recommended value for SGD model is 0.0001, for ASGD model is 0.00001.
|
||||
|
||||
- Initial step size parameter is the initial value for the step size \f$\gamma(t)\f$.
|
||||
You will have to find the best initial step for your problem.
|
||||
|
||||
- Step decreasing power is the power parameter for \f$\gamma(t)\f$ decreasing by the formula, mentioned above.
|
||||
Recommended value for SGD model is 1, for ASGD model is 0.75.
|
||||
|
||||
- Termination criteria can be TermCriteria::COUNT, TermCriteria::EPS or TermCriteria::COUNT + TermCriteria::EPS.
|
||||
You will have to find the best termination criteria for your problem.
|
||||
|
||||
Note that the parameters margin regularization, initial step size, and step decreasing power should be positive.
|
||||
|
||||
To use SVMSGD algorithm do as follows:
|
||||
|
||||
- first, create the SVMSGD object. The algoorithm will set optimal parameters by default, but you can set your own parameters via functions setSvmsgdType(),
|
||||
setMarginType(), setMarginRegularization(), setInitialStepSize(), and setStepDecreasingPower().
|
||||
|
||||
- then the SVM model can be trained using the train features and the correspondent labels by the method train().
|
||||
|
||||
- after that, the label of a new feature vector can be predicted using the method predict().
|
||||
|
||||
@code
|
||||
// Create empty object
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
// Train the Stochastic Gradient Descent SVM
|
||||
svmsgd->train(trainData);
|
||||
|
||||
// Predict labels for the new samples
|
||||
svmsgd->predict(samples, responses);
|
||||
@endcode
|
||||
|
||||
*/
|
||||
|
||||
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
|
||||
{
|
||||
public:
|
||||
|
||||
/** SVMSGD type.
|
||||
ASGD is often the preferable choice. */
|
||||
enum SvmsgdType
|
||||
{
|
||||
SGD, //!< Stochastic Gradient Descent
|
||||
ASGD //!< Average Stochastic Gradient Descent
|
||||
};
|
||||
|
||||
/** Margin type.*/
|
||||
enum MarginType
|
||||
{
|
||||
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
|
||||
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
|
||||
};
|
||||
|
||||
/**
|
||||
* @return the weights of the trained model (decision function f(x) = weights * x + shift).
|
||||
*/
|
||||
CV_WRAP virtual Mat getWeights() = 0;
|
||||
|
||||
/**
|
||||
* @return the shift of the trained model (decision function f(x) = weights * x + shift).
|
||||
*/
|
||||
CV_WRAP virtual float getShift() = 0;
|
||||
|
||||
/** @brief Creates empty model.
|
||||
* Use StatModel::train to train the model. Since %SVMSGD has several parameters, you may want to
|
||||
* find the best parameters for your problem or use setOptimalParameters() to set some default parameters.
|
||||
*/
|
||||
CV_WRAP static Ptr<SVMSGD> create();
|
||||
|
||||
/** @brief Function sets optimal parameters values for chosen SVM SGD model.
|
||||
* @param svmsgdType is the type of SVMSGD classifier.
|
||||
* @param marginType is the type of margin constraint.
|
||||
*/
|
||||
CV_WRAP virtual void setOptimalParameters(int svmsgdType = SVMSGD::ASGD, int marginType = SVMSGD::SOFT_MARGIN) = 0;
|
||||
|
||||
/** @brief %Algorithm type, one of SVMSGD::SvmsgdType. */
|
||||
/** @see setSvmsgdType */
|
||||
CV_WRAP virtual int getSvmsgdType() const = 0;
|
||||
/** @copybrief getSvmsgdType @see getSvmsgdType */
|
||||
CV_WRAP virtual void setSvmsgdType(int svmsgdType) = 0;
|
||||
|
||||
/** @brief %Margin type, one of SVMSGD::MarginType. */
|
||||
/** @see setMarginType */
|
||||
CV_WRAP virtual int getMarginType() const = 0;
|
||||
/** @copybrief getMarginType @see getMarginType */
|
||||
CV_WRAP virtual void setMarginType(int marginType) = 0;
|
||||
|
||||
/** @brief Parameter marginRegularization of a %SVMSGD optimization problem. */
|
||||
/** @see setMarginRegularization */
|
||||
CV_WRAP virtual float getMarginRegularization() const = 0;
|
||||
/** @copybrief getMarginRegularization @see getMarginRegularization */
|
||||
CV_WRAP virtual void setMarginRegularization(float marginRegularization) = 0;
|
||||
|
||||
/** @brief Parameter initialStepSize of a %SVMSGD optimization problem. */
|
||||
/** @see setInitialStepSize */
|
||||
CV_WRAP virtual float getInitialStepSize() const = 0;
|
||||
/** @copybrief getInitialStepSize @see getInitialStepSize */
|
||||
CV_WRAP virtual void setInitialStepSize(float InitialStepSize) = 0;
|
||||
|
||||
/** @brief Parameter stepDecreasingPower of a %SVMSGD optimization problem. */
|
||||
/** @see setStepDecreasingPower */
|
||||
CV_WRAP virtual float getStepDecreasingPower() const = 0;
|
||||
/** @copybrief getStepDecreasingPower @see getStepDecreasingPower */
|
||||
CV_WRAP virtual void setStepDecreasingPower(float stepDecreasingPower) = 0;
|
||||
|
||||
/** @brief Termination criteria of the training algorithm.
|
||||
You can specify the maximum number of iterations (maxCount) and/or how much the error could
|
||||
change between the iterations to make the algorithm continue (epsilon).*/
|
||||
/** @see setTermCriteria */
|
||||
CV_WRAP virtual TermCriteria getTermCriteria() const = 0;
|
||||
/** @copybrief getTermCriteria @see getTermCriteria */
|
||||
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
|
||||
};
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* Auxilary functions declarations *
|
||||
\****************************************************************************************/
|
||||
|
Reference in New Issue
Block a user