Reworked ML logistic regression implementation, initial version

This commit is contained in:
Maksim Shabunin
2014-08-14 19:01:45 +04:00
parent 71770eb790
commit 3e26086f82
4 changed files with 214 additions and 312 deletions

View File

@@ -571,56 +571,43 @@ public:
/****************************************************************************************\
* Logistic Regression *
\****************************************************************************************/
struct CV_EXPORTS LogisticRegressionParams
{
double alpha;
int num_iters;
int norm;
int regularized;
int train_method;
int mini_batch_size;
cv::TermCriteria term_crit;
LogisticRegressionParams();
LogisticRegressionParams(double learning_rate, int iters, int train_method, int normlization, int reg, int mini_batch_size);
};
class CV_EXPORTS LogisticRegression
class CV_EXPORTS LogisticRegression : public StatModel
{
public:
LogisticRegression( const LogisticRegressionParams& params = LogisticRegressionParams());
LogisticRegression(cv::InputArray data_ip, cv::InputArray labels_ip, const LogisticRegressionParams& params);
virtual ~LogisticRegression();
class CV_EXPORTS Params
{
public:
Params(double learning_rate = 0.001,
int iters = 1000,
int method = LogisticRegression::BATCH,
int normlization = LogisticRegression::REG_L2,
int reg = 1,
int batch_size = 1);
double alpha;
int num_iters;
int norm;
int regularized;
int train_method;
int mini_batch_size;
cv::TermCriteria term_crit;
};
enum { REG_L1 = 0, REG_L2 = 1};
enum { BATCH = 0, MINI_BATCH = 1};
virtual bool train(cv::InputArray data_ip, cv::InputArray label_ip);
virtual void predict( cv::InputArray data, cv::OutputArray predicted_labels ) const;
// Algorithm interface
virtual void write( FileStorage &fs ) const = 0;
virtual void read( const FileNode &fn ) = 0;
virtual void write(FileStorage& fs) const;
virtual void read(const FileNode& fn);
// StatModel interface
virtual bool train( const Ptr<TrainData>& trainData, int flags=0 ) = 0;
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags=0 ) const = 0;
virtual void clear() = 0;
const cv::Mat get_learnt_thetas() const;
virtual void clear();
virtual Mat get_learnt_thetas() const = 0;
protected:
LogisticRegressionParams params;
cv::Mat learnt_thetas;
std::string default_model_name;
std::map<int, int> forward_mapper;
std::map<int, int> reverse_mapper;
cv::Mat labels_o;
cv::Mat labels_n;
static cv::Mat calc_sigmoid(const cv::Mat& data);
virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
virtual bool set_label_map(const cv::Mat& labels);
static cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int>& lmap);
static Ptr<LogisticRegression> create( const Params& params = Params() );
};
/****************************************************************************************\