Updated ml module interfaces and documentation
This commit is contained in:
@@ -60,31 +60,41 @@ using namespace std;
|
||||
namespace cv {
|
||||
namespace ml {
|
||||
|
||||
LogisticRegression::Params::Params(double learning_rate,
|
||||
int iters,
|
||||
int method,
|
||||
int normlization,
|
||||
int reg,
|
||||
int batch_size)
|
||||
class LrParams
|
||||
{
|
||||
alpha = learning_rate;
|
||||
num_iters = iters;
|
||||
norm = normlization;
|
||||
regularized = reg;
|
||||
train_method = method;
|
||||
mini_batch_size = batch_size;
|
||||
term_crit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, num_iters, alpha);
|
||||
}
|
||||
public:
|
||||
LrParams()
|
||||
{
|
||||
alpha = 0.001;
|
||||
num_iters = 1000;
|
||||
norm = LogisticRegression::REG_L2;
|
||||
train_method = LogisticRegression::BATCH;
|
||||
mini_batch_size = 1;
|
||||
term_crit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, num_iters, alpha);
|
||||
}
|
||||
|
||||
double alpha; //!< learning rate.
|
||||
int num_iters; //!< number of iterations.
|
||||
int norm;
|
||||
int train_method;
|
||||
int mini_batch_size;
|
||||
TermCriteria term_crit;
|
||||
};
|
||||
|
||||
class LogisticRegressionImpl : public LogisticRegression
|
||||
{
|
||||
public:
|
||||
LogisticRegressionImpl(const Params& pms)
|
||||
: params(pms)
|
||||
{
|
||||
}
|
||||
|
||||
LogisticRegressionImpl() { }
|
||||
virtual ~LogisticRegressionImpl() {}
|
||||
|
||||
CV_IMPL_PROPERTY(double, LearningRate, params.alpha)
|
||||
CV_IMPL_PROPERTY(int, Iterations, params.num_iters)
|
||||
CV_IMPL_PROPERTY(int, Regularization, params.norm)
|
||||
CV_IMPL_PROPERTY(int, TrainMethod, params.train_method)
|
||||
CV_IMPL_PROPERTY(int, MiniBatchSize, params.mini_batch_size)
|
||||
CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit)
|
||||
|
||||
virtual bool train( const Ptr<TrainData>& trainData, int=0 );
|
||||
virtual float predict(InputArray samples, OutputArray results, int) const;
|
||||
virtual void clear();
|
||||
@@ -103,7 +113,7 @@ protected:
|
||||
bool set_label_map(const Mat& _labels_i);
|
||||
Mat remap_labels(const Mat& _labels_i, const map<int, int>& lmap) const;
|
||||
protected:
|
||||
Params params;
|
||||
LrParams params;
|
||||
Mat learnt_thetas;
|
||||
map<int, int> forward_mapper;
|
||||
map<int, int> reverse_mapper;
|
||||
@@ -111,9 +121,9 @@ protected:
|
||||
Mat labels_n;
|
||||
};
|
||||
|
||||
Ptr<LogisticRegression> LogisticRegression::create(const Params& params)
|
||||
Ptr<LogisticRegression> LogisticRegression::create()
|
||||
{
|
||||
return makePtr<LogisticRegressionImpl>(params);
|
||||
return makePtr<LogisticRegressionImpl>();
|
||||
}
|
||||
|
||||
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||
@@ -312,7 +322,7 @@ double LogisticRegressionImpl::compute_cost(const Mat& _data, const Mat& _labels
|
||||
theta_b = _init_theta(Range(1, n), Range::all());
|
||||
multiply(theta_b, theta_b, theta_c, 1);
|
||||
|
||||
if(this->params.regularized > 0)
|
||||
if(params.norm != REG_NONE)
|
||||
{
|
||||
llambda = 1;
|
||||
}
|
||||
@@ -367,7 +377,7 @@ Mat LogisticRegressionImpl::compute_batch_gradient(const Mat& _data, const Mat&
|
||||
m = _data.rows;
|
||||
n = _data.cols;
|
||||
|
||||
if(this->params.regularized > 0)
|
||||
if(params.norm != REG_NONE)
|
||||
{
|
||||
llambda = 1;
|
||||
}
|
||||
@@ -439,7 +449,7 @@ Mat LogisticRegressionImpl::compute_mini_batch_gradient(const Mat& _data, const
|
||||
Mat data_d;
|
||||
Mat labels_l;
|
||||
|
||||
if(this->params.regularized > 0)
|
||||
if(params.norm != REG_NONE)
|
||||
{
|
||||
lambda_l = 1;
|
||||
}
|
||||
@@ -570,7 +580,6 @@ void LogisticRegressionImpl::write(FileStorage& fs) const
|
||||
fs<<"alpha"<<this->params.alpha;
|
||||
fs<<"iterations"<<this->params.num_iters;
|
||||
fs<<"norm"<<this->params.norm;
|
||||
fs<<"regularized"<<this->params.regularized;
|
||||
fs<<"train_method"<<this->params.train_method;
|
||||
if(this->params.train_method == LogisticRegression::MINI_BATCH)
|
||||
{
|
||||
@@ -592,7 +601,6 @@ void LogisticRegressionImpl::read(const FileNode& fn)
|
||||
this->params.alpha = (double)fn["alpha"];
|
||||
this->params.num_iters = (int)fn["iterations"];
|
||||
this->params.norm = (int)fn["norm"];
|
||||
this->params.regularized = (int)fn["regularized"];
|
||||
this->params.train_method = (int)fn["train_method"];
|
||||
|
||||
if(this->params.train_method == LogisticRegression::MINI_BATCH)
|
||||
|
Reference in New Issue
Block a user