Reworked ML logistic regression implementation, initial version
This commit is contained in:
@@ -571,56 +571,43 @@ public:
|
||||
/****************************************************************************************\
|
||||
* Logistic Regression *
|
||||
\****************************************************************************************/
|
||||
struct CV_EXPORTS LogisticRegressionParams
|
||||
{
|
||||
double alpha;
|
||||
int num_iters;
|
||||
int norm;
|
||||
int regularized;
|
||||
int train_method;
|
||||
int mini_batch_size;
|
||||
cv::TermCriteria term_crit;
|
||||
|
||||
LogisticRegressionParams();
|
||||
LogisticRegressionParams(double learning_rate, int iters, int train_method, int normlization, int reg, int mini_batch_size);
|
||||
};
|
||||
|
||||
class CV_EXPORTS LogisticRegression
|
||||
class CV_EXPORTS LogisticRegression : public StatModel
|
||||
{
|
||||
public:
|
||||
LogisticRegression( const LogisticRegressionParams& params = LogisticRegressionParams());
|
||||
LogisticRegression(cv::InputArray data_ip, cv::InputArray labels_ip, const LogisticRegressionParams& params);
|
||||
virtual ~LogisticRegression();
|
||||
class CV_EXPORTS Params
|
||||
{
|
||||
public:
|
||||
Params(double learning_rate = 0.001,
|
||||
int iters = 1000,
|
||||
int method = LogisticRegression::BATCH,
|
||||
int normlization = LogisticRegression::REG_L2,
|
||||
int reg = 1,
|
||||
int batch_size = 1);
|
||||
double alpha;
|
||||
int num_iters;
|
||||
int norm;
|
||||
int regularized;
|
||||
int train_method;
|
||||
int mini_batch_size;
|
||||
cv::TermCriteria term_crit;
|
||||
};
|
||||
|
||||
enum { REG_L1 = 0, REG_L2 = 1};
|
||||
enum { BATCH = 0, MINI_BATCH = 1};
|
||||
|
||||
virtual bool train(cv::InputArray data_ip, cv::InputArray label_ip);
|
||||
virtual void predict( cv::InputArray data, cv::OutputArray predicted_labels ) const;
|
||||
// Algorithm interface
|
||||
virtual void write( FileStorage &fs ) const = 0;
|
||||
virtual void read( const FileNode &fn ) = 0;
|
||||
|
||||
virtual void write(FileStorage& fs) const;
|
||||
virtual void read(const FileNode& fn);
|
||||
// StatModel interface
|
||||
virtual bool train( const Ptr<TrainData>& trainData, int flags=0 ) = 0;
|
||||
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags=0 ) const = 0;
|
||||
virtual void clear() = 0;
|
||||
|
||||
const cv::Mat get_learnt_thetas() const;
|
||||
virtual void clear();
|
||||
virtual Mat get_learnt_thetas() const = 0;
|
||||
|
||||
protected:
|
||||
|
||||
LogisticRegressionParams params;
|
||||
cv::Mat learnt_thetas;
|
||||
std::string default_model_name;
|
||||
std::map<int, int> forward_mapper;
|
||||
std::map<int, int> reverse_mapper;
|
||||
|
||||
cv::Mat labels_o;
|
||||
cv::Mat labels_n;
|
||||
|
||||
static cv::Mat calc_sigmoid(const cv::Mat& data);
|
||||
virtual double compute_cost(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual cv::Mat compute_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual cv::Mat compute_mini_batch_gradient(const cv::Mat& data, const cv::Mat& labels, const cv::Mat& init_theta);
|
||||
virtual bool set_label_map(const cv::Mat& labels);
|
||||
static cv::Mat remap_labels(const cv::Mat& labels, const std::map<int, int>& lmap);
|
||||
static Ptr<LogisticRegression> create( const Params& params = Params() );
|
||||
};
|
||||
|
||||
/****************************************************************************************\
|
||||
|
Reference in New Issue
Block a user