added updated logistic regression prototype with newer C++ API
This commit is contained in:
parent
0e13f33193
commit
d5ad4f3255
@ -571,81 +571,66 @@ public:
|
||||
/****************************************************************************************\
|
||||
* Logistic Regression *
|
||||
\****************************************************************************************/
|
||||
|
||||
struct CV_EXPORTS_W_MAP CvLR_TrainParams
|
||||
namespace cv
|
||||
{
|
||||
CV_PROP_RW double alpha;
|
||||
CV_PROP_RW int num_iters;
|
||||
CV_PROP_RW int norm;
|
||||
///////////////////////////////////////////////////
|
||||
// CV_PROP_RW int debug;
|
||||
///////////////////////////////////////////////////
|
||||
CV_PROP_RW int regularized;
|
||||
CV_PROP_RW int train_method;
|
||||
CV_PROP_RW int minibatchsize;
|
||||
struct CV_EXPORTS LogisticRegressionParams
|
||||
{
|
||||
double alpha;
|
||||
int num_iters;
|
||||
int norm;
|
||||
int regularized;
|
||||
int train_method;
|
||||
int mini_batch_size;
|
||||
CvTermCriteria term_crit;
|
||||
|
||||
CV_PROP_RW CvTermCriteria term_crit;
|
||||
|
||||
CvLR_TrainParams();
|
||||
///////////////////////////////////////////////////
|
||||
// CvLR_TrainParams(double alpha, int num_iters, int norm, int debug, int regularized, int train_method, int minbatchsize);
|
||||
///////////////////////////////////////////////////
|
||||
CvLR_TrainParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize);
|
||||
~CvLR_TrainParams();
|
||||
LogisticRegressionParams();
|
||||
LogisticRegressionParams(double alpha, int num_iters, int norm, int regularized, int train_method, int minbatchsize);
|
||||
};
|
||||
|
||||
class CV_EXPORTS_W CvLR : public CvStatModel
|
||||
class CV_EXPORTS LogisticRegression
|
||||
{
|
||||
public:
|
||||
CvLR();
|
||||
// CvLR(const CvLR_TrainParams& Params);
|
||||
|
||||
CvLR(const cv::Mat& data, const cv::Mat& labels, const CvLR_TrainParams& params);
|
||||
LogisticRegression();
|
||||
LogisticRegression(cv::InputArray data_ip, cv::InputArray labels_ip, const LogisticRegressionParams& params);
|
||||
virtual ~LogisticRegression();
|
||||
|
||||
virtual ~CvLR();
|
||||
enum { REG_L1 = 0, REG_L2 = 1};
|
||||
enum { BATCH = 0, MINI_BATCH = 1};
|
||||
|
||||
enum { REG_L1=0, REG_L2 = 1};
|
||||
enum { BATCH, MINI_BATCH};
|
||||
virtual bool train(cv::InputArray data_ip, cv::InputArray label_ip);
|
||||
virtual void predict( cv::InputArray data, cv::OutputArray predicted_labels ) const;
|
||||
|
||||
virtual void save(std::string filepath) const;
|
||||
virtual void load(const std::string filepath);
|
||||
|
||||
virtual bool train(const cv::Mat& data, const cv::Mat& labels);//, const CvLR_TrainParams& params);
|
||||
|
||||
virtual float predict(const cv::Mat& data, cv::Mat& predicted_labels);
|
||||
virtual float predict(const cv::Mat& data);
|
||||
|
||||
virtual void write( CvFileStorage* storage, const char* name ) const;
|
||||
virtual void read( CvFileStorage* storage, CvFileNode* node );
|
||||
|
||||
virtual void clear();
|
||||
|
||||
virtual cv::Mat get_learnt_mat();
|
||||
cv::Mat get_learnt_thetas() const;
|
||||
|
||||
protected:
|
||||
|
||||
LogisticRegressionParams params;
|
||||
cv::Mat learnt_thetas;
|
||||
CvLR_TrainParams params;
|
||||
|
||||
std::string default_model_name;
|
||||
std::map<int, int> forward_mapper;
|
||||
std::map<int, int> reverse_mapper;
|
||||
|
||||
virtual bool set_default_params();
|
||||
virtual cv::Mat calc_sigmoid(const cv::Mat& data);
|
||||
|
||||
virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
|
||||
virtual std::map<int, int> get_label_map(const cv::Mat& labels);
|
||||
|
||||
virtual bool set_label_map(const cv::Mat& labels);
|
||||
virtual cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int> lmap);
|
||||
|
||||
//cv::Mat Mapper;
|
||||
|
||||
cv::Mat labels_o;
|
||||
cv::Mat labels_n;
|
||||
|
||||
static cv::Mat calc_sigmoid(const cv::Mat& data);
|
||||
|
||||
virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual bool set_label_map(const cv::Mat& labels);
|
||||
static cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int>& lmap);
|
||||
|
||||
virtual void write(FileStorage& fs) const;
|
||||
virtual void read(const FileNode& fn);
|
||||
virtual void clear();
|
||||
|
||||
};
|
||||
}// namespace cv
|
||||
|
||||
/****************************************************************************************\
|
||||
* Auxilary functions declarations *
|
||||
|
Loading…
x
Reference in New Issue
Block a user