diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 346df8f15..75fb6d134 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -1588,7 +1588,6 @@ public: ASGD is often the preferable choice. */ enum SvmsgdType { - ILLEGAL_SVMSGD_TYPE, SGD, //!< Stochastic Gradient Descent ASGD //!< Average Stochastic Gradient Descent }; @@ -1596,7 +1595,6 @@ public: /** Margin type.*/ enum MarginType { - ILLEGAL_MARGIN_TYPE, SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers. HARD_MARGIN //!< More accurate for the case of linearly separable sets. }; diff --git a/modules/ml/src/svmsgd.cpp b/modules/ml/src/svmsgd.cpp index 77ac2ad67..0f1efcfea 100644 --- a/modules/ml/src/svmsgd.cpp +++ b/modules/ml/src/svmsgd.cpp @@ -89,14 +89,8 @@ public: virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN); - virtual int getSvmsgdType() const; - - virtual void setSvmsgdType(int svmsgdType); - - virtual int getMarginType() const; - - virtual void setMarginType(int marginType); - + CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType) + CV_IMPL_PROPERTY(int, MarginType, params.marginType) CV_IMPL_PROPERTY(float, Lambda, params.lambda) CV_IMPL_PROPERTY(float, Gamma0, params.gamma0) CV_IMPL_PROPERTY(float, C, params.c) @@ -132,8 +126,8 @@ private: float gamma0; //learning rate float c; TermCriteria termCrit; - SvmsgdType svmsgdType; - MarginType marginType; + int svmsgdType; + int marginType; }; SVMSGDParams params; @@ -148,9 +142,9 @@ std::pair SVMSGDImpl::areClassesEmpty(Mat responses) { CV_Assert(responses.cols == 1 || responses.rows == 1); std::pair emptyInClasses(true, true); - int limit_index = responses.rows; + int limitIndex = responses.rows; - for(int index = 0; index < limit_index; index++) + for(int index = 0; index < limitIndex; index++) { if (isPositive(responses.at(index))) emptyInClasses.first = false; @@ -276,9 +270,9 @@ bool SVMSGDImpl::train(const Ptr& data, int) int extendedTrainSamplesCount = extendedTrainSamples.rows; int extendedFeatureCount = extendedTrainSamples.cols; - Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros - Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion - Mat averageExtendedWeights; //average extendedWeights vector for ASGD model + Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); + Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); + Mat averageExtendedWeights; if (params.svmsgdType == ASGD) { averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); @@ -407,10 +401,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const case ASGD: SvmsgdTypeStr = "ASGD"; break; - case ILLEGAL_SVMSGD_TYPE: - SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType); default: - std::cout << "params.svmsgdType isn't initialized" << std::endl; + SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType); } fs << "svmsgdType" << SvmsgdTypeStr; @@ -425,10 +417,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const case HARD_MARGIN: marginTypeStr = "HARD_MARGIN"; break; - case ILLEGAL_MARGIN_TYPE: - marginTypeStr = format("Unknown_%d", params.marginType); default: - std::cout << "params.marginType isn't initialized" << std::endl; + marginTypeStr = format("Unknown_%d", params.marginType); } fs << "marginType" << marginTypeStr; @@ -458,21 +448,21 @@ void SVMSGDImpl::read(const FileNode& fn) void SVMSGDImpl::readParams( const FileNode& fn ) { String svmsgdTypeStr = (String)fn["svmsgdType"]; - SvmsgdType svmsgdType = + int svmsgdType = svmsgdTypeStr == "SGD" ? SGD : - svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE; + svmsgdTypeStr == "ASGD" ? ASGD : -1; - if( svmsgdType == ILLEGAL_SVMSGD_TYPE ) + if( svmsgdType < 0 ) CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" ); params.svmsgdType = svmsgdType; String marginTypeStr = (String)fn["marginType"]; - MarginType marginType = + int marginType = marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN : - marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; + marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1; - if( marginType == ILLEGAL_MARGIN_TYPE ) + if( marginType < 0 ) CV_Error( CV_StsParseError, "Missing or invalid margin type" ); params.marginType = marginType; @@ -510,8 +500,8 @@ SVMSGDImpl::SVMSGDImpl() { clear(); - params.svmsgdType = ILLEGAL_SVMSGD_TYPE; - params.marginType = ILLEGAL_MARGIN_TYPE; + params.svmsgdType = -1; + params.marginType = -1; // Parameters for learning params.lambda = 0; // regularization @@ -529,7 +519,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) case SGD: params.svmsgdType = SGD; params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : - (marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; + (marginType == HARD_MARGIN) ? HARD_MARGIN : -1; params.lambda = 0.0001f; params.gamma0 = 0.05f; params.c = 1.f; @@ -539,7 +529,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) case ASGD: params.svmsgdType = ASGD; params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : - (marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; + (marginType == HARD_MARGIN) ? HARD_MARGIN : -1; params.lambda = 0.00001f; params.gamma0 = 0.05f; params.c = 0.75f; @@ -550,45 +540,5 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType) CV_Error( CV_StsParseError, "SVMSGD model data is invalid" ); } } - -void SVMSGDImpl::setSvmsgdType(int type) -{ - switch (type) - { - case SGD: - params.svmsgdType = SGD; - break; - case ASGD: - params.svmsgdType = ASGD; - break; - default: - params.svmsgdType = ILLEGAL_SVMSGD_TYPE; - } -} - -int SVMSGDImpl::getSvmsgdType() const -{ - return params.svmsgdType; -} - -void SVMSGDImpl::setMarginType(int type) -{ - switch (type) - { - case HARD_MARGIN: - params.marginType = HARD_MARGIN; - break; - case SOFT_MARGIN: - params.marginType = SOFT_MARGIN; - break; - default: - params.marginType = ILLEGAL_MARGIN_TYPE; - } -} - -int SVMSGDImpl::getMarginType() const -{ - return params.marginType; -} } //ml } //cv