Updated ml module interfaces and documentation
This commit is contained in:
@@ -103,54 +103,60 @@ static void checkParamGrid(const ParamGrid& pg)
|
||||
}
|
||||
|
||||
// SVM training parameters
|
||||
SVM::Params::Params()
|
||||
struct SvmParams
|
||||
{
|
||||
svmType = SVM::C_SVC;
|
||||
kernelType = SVM::RBF;
|
||||
degree = 0;
|
||||
gamma = 1;
|
||||
coef0 = 0;
|
||||
C = 1;
|
||||
nu = 0;
|
||||
p = 0;
|
||||
termCrit = TermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
|
||||
}
|
||||
int svmType;
|
||||
int kernelType;
|
||||
double gamma;
|
||||
double coef0;
|
||||
double degree;
|
||||
double C;
|
||||
double nu;
|
||||
double p;
|
||||
Mat classWeights;
|
||||
TermCriteria termCrit;
|
||||
|
||||
SvmParams()
|
||||
{
|
||||
svmType = SVM::C_SVC;
|
||||
kernelType = SVM::RBF;
|
||||
degree = 0;
|
||||
gamma = 1;
|
||||
coef0 = 0;
|
||||
C = 1;
|
||||
nu = 0;
|
||||
p = 0;
|
||||
termCrit = TermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
|
||||
}
|
||||
|
||||
SVM::Params::Params( int _svmType, int _kernelType,
|
||||
double _degree, double _gamma, double _coef0,
|
||||
double _Con, double _nu, double _p,
|
||||
const Mat& _classWeights, TermCriteria _termCrit )
|
||||
{
|
||||
svmType = _svmType;
|
||||
kernelType = _kernelType;
|
||||
degree = _degree;
|
||||
gamma = _gamma;
|
||||
coef0 = _coef0;
|
||||
C = _Con;
|
||||
nu = _nu;
|
||||
p = _p;
|
||||
classWeights = _classWeights;
|
||||
termCrit = _termCrit;
|
||||
}
|
||||
SvmParams( int _svmType, int _kernelType,
|
||||
double _degree, double _gamma, double _coef0,
|
||||
double _Con, double _nu, double _p,
|
||||
const Mat& _classWeights, TermCriteria _termCrit )
|
||||
{
|
||||
svmType = _svmType;
|
||||
kernelType = _kernelType;
|
||||
degree = _degree;
|
||||
gamma = _gamma;
|
||||
coef0 = _coef0;
|
||||
C = _Con;
|
||||
nu = _nu;
|
||||
p = _p;
|
||||
classWeights = _classWeights;
|
||||
termCrit = _termCrit;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////// SVM kernel ///////////////////////////////////////
|
||||
class SVMKernelImpl : public SVM::Kernel
|
||||
{
|
||||
public:
|
||||
SVMKernelImpl()
|
||||
{
|
||||
}
|
||||
|
||||
SVMKernelImpl( const SVM::Params& _params )
|
||||
SVMKernelImpl( const SvmParams& _params = SvmParams() )
|
||||
{
|
||||
params = _params;
|
||||
}
|
||||
|
||||
virtual ~SVMKernelImpl()
|
||||
{
|
||||
}
|
||||
|
||||
int getType() const
|
||||
{
|
||||
return params.kernelType;
|
||||
@@ -327,7 +333,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
SVM::Params params;
|
||||
SvmParams params;
|
||||
};
|
||||
|
||||
|
||||
@@ -1185,7 +1191,7 @@ public:
|
||||
int cache_size;
|
||||
int max_cache_size;
|
||||
Mat samples;
|
||||
SVM::Params params;
|
||||
SvmParams params;
|
||||
vector<KernelRow> lru_cache;
|
||||
int lru_first;
|
||||
int lru_last;
|
||||
@@ -1215,6 +1221,7 @@ public:
|
||||
SVMImpl()
|
||||
{
|
||||
clear();
|
||||
checkParams();
|
||||
}
|
||||
|
||||
~SVMImpl()
|
||||
@@ -1235,33 +1242,69 @@ public:
|
||||
return sv;
|
||||
}
|
||||
|
||||
void setParams( const Params& _params, const Ptr<Kernel>& _kernel )
|
||||
CV_IMPL_PROPERTY(int, Type, params.svmType)
|
||||
CV_IMPL_PROPERTY(double, Gamma, params.gamma)
|
||||
CV_IMPL_PROPERTY(double, Coef0, params.coef0)
|
||||
CV_IMPL_PROPERTY(double, Degree, params.degree)
|
||||
CV_IMPL_PROPERTY(double, C, params.C)
|
||||
CV_IMPL_PROPERTY(double, Nu, params.nu)
|
||||
CV_IMPL_PROPERTY(double, P, params.p)
|
||||
CV_IMPL_PROPERTY_S(cv::Mat, ClassWeights, params.classWeights)
|
||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||
|
||||
int getKernelType() const
|
||||
{
|
||||
params = _params;
|
||||
return params.kernelType;
|
||||
}
|
||||
|
||||
void setKernel(int kernelType)
|
||||
{
|
||||
params.kernelType = kernelType;
|
||||
if (kernelType != CUSTOM)
|
||||
kernel = makePtr<SVMKernelImpl>(params);
|
||||
}
|
||||
|
||||
void setCustomKernel(const Ptr<Kernel> &_kernel)
|
||||
{
|
||||
params.kernelType = CUSTOM;
|
||||
kernel = _kernel;
|
||||
}
|
||||
|
||||
void checkParams()
|
||||
{
|
||||
int kernelType = params.kernelType;
|
||||
if (kernelType != CUSTOM)
|
||||
{
|
||||
if( kernelType != LINEAR && kernelType != POLY &&
|
||||
kernelType != SIGMOID && kernelType != RBF &&
|
||||
kernelType != INTER && kernelType != CHI2)
|
||||
CV_Error( CV_StsBadArg, "Unknown/unsupported kernel type" );
|
||||
|
||||
if( kernelType == LINEAR )
|
||||
params.gamma = 1;
|
||||
else if( params.gamma <= 0 )
|
||||
CV_Error( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
|
||||
|
||||
if( kernelType != SIGMOID && kernelType != POLY )
|
||||
params.coef0 = 0;
|
||||
else if( params.coef0 < 0 )
|
||||
CV_Error( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
|
||||
|
||||
if( kernelType != POLY )
|
||||
params.degree = 0;
|
||||
else if( params.degree <= 0 )
|
||||
CV_Error( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
|
||||
|
||||
kernel = makePtr<SVMKernelImpl>(params);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!kernel)
|
||||
CV_Error( CV_StsBadArg, "Custom kernel is not set" );
|
||||
}
|
||||
|
||||
int svmType = params.svmType;
|
||||
|
||||
if( kernelType != LINEAR && kernelType != POLY &&
|
||||
kernelType != SIGMOID && kernelType != RBF &&
|
||||
kernelType != INTER && kernelType != CHI2)
|
||||
CV_Error( CV_StsBadArg, "Unknown/unsupported kernel type" );
|
||||
|
||||
if( kernelType == LINEAR )
|
||||
params.gamma = 1;
|
||||
else if( params.gamma <= 0 )
|
||||
CV_Error( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
|
||||
|
||||
if( kernelType != SIGMOID && kernelType != POLY )
|
||||
params.coef0 = 0;
|
||||
else if( params.coef0 < 0 )
|
||||
CV_Error( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
|
||||
|
||||
if( kernelType != POLY )
|
||||
params.degree = 0;
|
||||
else if( params.degree <= 0 )
|
||||
CV_Error( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
|
||||
|
||||
if( svmType != C_SVC && svmType != NU_SVC &&
|
||||
svmType != ONE_CLASS && svmType != EPS_SVR &&
|
||||
svmType != NU_SVR )
|
||||
@@ -1285,28 +1328,18 @@ public:
|
||||
if( svmType != C_SVC )
|
||||
params.classWeights.release();
|
||||
|
||||
termCrit = params.termCrit;
|
||||
if( !(termCrit.type & TermCriteria::EPS) )
|
||||
termCrit.epsilon = DBL_EPSILON;
|
||||
termCrit.epsilon = std::max(termCrit.epsilon, DBL_EPSILON);
|
||||
if( !(termCrit.type & TermCriteria::COUNT) )
|
||||
termCrit.maxCount = INT_MAX;
|
||||
termCrit.maxCount = std::max(termCrit.maxCount, 1);
|
||||
|
||||
if( _kernel )
|
||||
kernel = _kernel;
|
||||
else
|
||||
kernel = makePtr<SVMKernelImpl>(params);
|
||||
if( !(params.termCrit.type & TermCriteria::EPS) )
|
||||
params.termCrit.epsilon = DBL_EPSILON;
|
||||
params.termCrit.epsilon = std::max(params.termCrit.epsilon, DBL_EPSILON);
|
||||
if( !(params.termCrit.type & TermCriteria::COUNT) )
|
||||
params.termCrit.maxCount = INT_MAX;
|
||||
params.termCrit.maxCount = std::max(params.termCrit.maxCount, 1);
|
||||
}
|
||||
|
||||
Params getParams() const
|
||||
void setParams( const SvmParams& _params)
|
||||
{
|
||||
return params;
|
||||
}
|
||||
|
||||
Ptr<Kernel> getKernel() const
|
||||
{
|
||||
return kernel;
|
||||
params = _params;
|
||||
checkParams();
|
||||
}
|
||||
|
||||
int getSVCount(int i) const
|
||||
@@ -1335,9 +1368,9 @@ public:
|
||||
_responses.convertTo(_yf, CV_32F);
|
||||
|
||||
bool ok =
|
||||
svmType == ONE_CLASS ? Solver::solve_one_class( _samples, params.nu, kernel, _alpha, sinfo, termCrit ) :
|
||||
svmType == EPS_SVR ? Solver::solve_eps_svr( _samples, _yf, params.p, params.C, kernel, _alpha, sinfo, termCrit ) :
|
||||
svmType == NU_SVR ? Solver::solve_nu_svr( _samples, _yf, params.nu, params.C, kernel, _alpha, sinfo, termCrit ) : false;
|
||||
svmType == ONE_CLASS ? Solver::solve_one_class( _samples, params.nu, kernel, _alpha, sinfo, params.termCrit ) :
|
||||
svmType == EPS_SVR ? Solver::solve_eps_svr( _samples, _yf, params.p, params.C, kernel, _alpha, sinfo, params.termCrit ) :
|
||||
svmType == NU_SVR ? Solver::solve_nu_svr( _samples, _yf, params.nu, params.C, kernel, _alpha, sinfo, params.termCrit ) : false;
|
||||
|
||||
if( !ok )
|
||||
return false;
|
||||
@@ -1397,7 +1430,7 @@ public:
|
||||
//check that while cross-validation there were the samples from all the classes
|
||||
if( class_ranges[class_count] <= 0 )
|
||||
CV_Error( CV_StsBadArg, "While cross-validation one or more of the classes have "
|
||||
"been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
|
||||
"been fell out of the sample. Try to enlarge <Params::k_fold>" );
|
||||
|
||||
if( svmType == NU_SVC )
|
||||
{
|
||||
@@ -1448,10 +1481,10 @@ public:
|
||||
DecisionFunc df;
|
||||
bool ok = params.svmType == C_SVC ?
|
||||
Solver::solve_c_svc( temp_samples, temp_y, Cp, Cn,
|
||||
kernel, _alpha, sinfo, termCrit ) :
|
||||
kernel, _alpha, sinfo, params.termCrit ) :
|
||||
params.svmType == NU_SVC ?
|
||||
Solver::solve_nu_svc( temp_samples, temp_y, params.nu,
|
||||
kernel, _alpha, sinfo, termCrit ) :
|
||||
kernel, _alpha, sinfo, params.termCrit ) :
|
||||
false;
|
||||
if( !ok )
|
||||
return false;
|
||||
@@ -1557,6 +1590,8 @@ public:
|
||||
{
|
||||
clear();
|
||||
|
||||
checkParams();
|
||||
|
||||
int svmType = params.svmType;
|
||||
Mat samples = data->getTrainSamples();
|
||||
Mat responses;
|
||||
@@ -1586,6 +1621,8 @@ public:
|
||||
ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
|
||||
bool balanced )
|
||||
{
|
||||
checkParams();
|
||||
|
||||
int svmType = params.svmType;
|
||||
RNG rng((uint64)-1);
|
||||
|
||||
@@ -1708,7 +1745,7 @@ public:
|
||||
int test_sample_count = (sample_count + k_fold/2)/k_fold;
|
||||
int train_sample_count = sample_count - test_sample_count;
|
||||
|
||||
Params best_params = params;
|
||||
SvmParams best_params = params;
|
||||
double min_error = FLT_MAX;
|
||||
|
||||
int rtype = responses.type();
|
||||
@@ -1729,7 +1766,7 @@ public:
|
||||
FOR_IN_GRID(degree, degree_grid)
|
||||
{
|
||||
// make sure we updated the kernel and other parameters
|
||||
setParams(params, Ptr<Kernel>() );
|
||||
setParams(params);
|
||||
|
||||
double error = 0;
|
||||
for( k = 0; k < k_fold; k++ )
|
||||
@@ -1919,7 +1956,9 @@ public:
|
||||
kernelType == LINEAR ? "LINEAR" :
|
||||
kernelType == POLY ? "POLY" :
|
||||
kernelType == RBF ? "RBF" :
|
||||
kernelType == SIGMOID ? "SIGMOID" : format("Unknown_%d", kernelType);
|
||||
kernelType == SIGMOID ? "SIGMOID" :
|
||||
kernelType == CHI2 ? "CHI2" :
|
||||
kernelType == INTER ? "INTER" : format("Unknown_%d", kernelType);
|
||||
|
||||
fs << "svmType" << svm_type_str;
|
||||
|
||||
@@ -2036,7 +2075,7 @@ public:
|
||||
|
||||
void read_params( const FileNode& fn )
|
||||
{
|
||||
Params _params;
|
||||
SvmParams _params;
|
||||
|
||||
// check for old naming
|
||||
String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
|
||||
@@ -2059,10 +2098,12 @@ public:
|
||||
kernel_type_str == "LINEAR" ? LINEAR :
|
||||
kernel_type_str == "POLY" ? POLY :
|
||||
kernel_type_str == "RBF" ? RBF :
|
||||
kernel_type_str == "SIGMOID" ? SIGMOID : -1;
|
||||
kernel_type_str == "SIGMOID" ? SIGMOID :
|
||||
kernel_type_str == "CHI2" ? CHI2 :
|
||||
kernel_type_str == "INTER" ? INTER : CUSTOM;
|
||||
|
||||
if( kernelType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing of invalid SVM kernel type" );
|
||||
if( kernelType == CUSTOM )
|
||||
CV_Error( CV_StsParseError, "Invalid SVM kernel type (or custom kernel)" );
|
||||
|
||||
_params.svmType = svmType;
|
||||
_params.kernelType = kernelType;
|
||||
@@ -2086,7 +2127,7 @@ public:
|
||||
else
|
||||
_params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );
|
||||
|
||||
setParams( _params, Ptr<Kernel>() );
|
||||
setParams( _params );
|
||||
}
|
||||
|
||||
void read( const FileNode& fn )
|
||||
@@ -2154,8 +2195,7 @@ public:
|
||||
optimize_linear_svm();
|
||||
}
|
||||
|
||||
Params params;
|
||||
TermCriteria termCrit;
|
||||
SvmParams params;
|
||||
Mat class_labels;
|
||||
int var_count;
|
||||
Mat sv;
|
||||
@@ -2167,11 +2207,9 @@ public:
|
||||
};
|
||||
|
||||
|
||||
Ptr<SVM> SVM::create(const Params& params, const Ptr<SVM::Kernel>& kernel)
|
||||
Ptr<SVM> SVM::create()
|
||||
{
|
||||
Ptr<SVMImpl> p = makePtr<SVMImpl>();
|
||||
p->setParams(params, kernel);
|
||||
return p;
|
||||
return makePtr<SVMImpl>();
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user