Merge pull request #5964 from amroamroamro:fix_lr
This commit is contained in:
commit
a1d7e38adb
@ -96,11 +96,11 @@ public:
|
|||||||
CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit)
|
CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit)
|
||||||
|
|
||||||
virtual bool train( const Ptr<TrainData>& trainData, int=0 );
|
virtual bool train( const Ptr<TrainData>& trainData, int=0 );
|
||||||
virtual float predict(InputArray samples, OutputArray results, int) const;
|
virtual float predict(InputArray samples, OutputArray results, int flags=0) const;
|
||||||
virtual void clear();
|
virtual void clear();
|
||||||
virtual void write(FileStorage& fs) const;
|
virtual void write(FileStorage& fs) const;
|
||||||
virtual void read(const FileNode& fn);
|
virtual void read(const FileNode& fn);
|
||||||
virtual Mat get_learnt_thetas() const;
|
virtual Mat get_learnt_thetas() const { return learnt_thetas; }
|
||||||
virtual int getVarCount() const { return learnt_thetas.cols; }
|
virtual int getVarCount() const { return learnt_thetas.cols; }
|
||||||
virtual bool isTrained() const { return !learnt_thetas.empty(); }
|
virtual bool isTrained() const { return !learnt_thetas.empty(); }
|
||||||
virtual bool isClassifier() const { return true; }
|
virtual bool isClassifier() const { return true; }
|
||||||
@ -129,57 +129,48 @@ Ptr<LogisticRegression> LogisticRegression::create()
|
|||||||
|
|
||||||
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||||
{
|
{
|
||||||
|
// return value
|
||||||
|
bool ok = false;
|
||||||
|
|
||||||
clear();
|
clear();
|
||||||
Mat _data_i = trainData->getSamples();
|
Mat _data_i = trainData->getSamples();
|
||||||
Mat _labels_i = trainData->getResponses();
|
Mat _labels_i = trainData->getResponses();
|
||||||
|
|
||||||
|
// check size and type of training data
|
||||||
CV_Assert( !_labels_i.empty() && !_data_i.empty());
|
CV_Assert( !_labels_i.empty() && !_data_i.empty());
|
||||||
|
|
||||||
// check the number of columns
|
|
||||||
if(_labels_i.cols != 1)
|
if(_labels_i.cols != 1)
|
||||||
{
|
{
|
||||||
CV_Error( CV_StsBadArg, "_labels_i should be a column matrix" );
|
CV_Error( CV_StsBadArg, "labels should be a column matrix" );
|
||||||
}
|
}
|
||||||
|
if(_data_i.type() != CV_32FC1 || _labels_i.type() != CV_32FC1)
|
||||||
// check data type.
|
|
||||||
// data should be of floating type CV_32FC1
|
|
||||||
|
|
||||||
if((_data_i.type() != CV_32FC1) || (_labels_i.type() != CV_32FC1))
|
|
||||||
{
|
{
|
||||||
CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" );
|
CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" );
|
||||||
}
|
}
|
||||||
|
if(_labels_i.rows != _data_i.rows)
|
||||||
|
{
|
||||||
|
CV_Error( CV_StsBadArg, "number of rows in data and labels should be equal" );
|
||||||
|
}
|
||||||
|
|
||||||
bool ok = false;
|
// class labels
|
||||||
|
|
||||||
Mat labels;
|
|
||||||
|
|
||||||
set_label_map(_labels_i);
|
set_label_map(_labels_i);
|
||||||
|
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
|
||||||
int num_classes = (int) this->forward_mapper.size();
|
int num_classes = (int) this->forward_mapper.size();
|
||||||
|
|
||||||
// add a column of ones
|
|
||||||
Mat data_t;
|
|
||||||
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
|
|
||||||
|
|
||||||
if(num_classes < 2)
|
if(num_classes < 2)
|
||||||
{
|
{
|
||||||
CV_Error( CV_StsBadArg, "data should have atleast 2 classes" );
|
CV_Error( CV_StsBadArg, "data should have atleast 2 classes" );
|
||||||
}
|
}
|
||||||
|
|
||||||
if(_labels_i.rows != _data_i.rows)
|
// add a column of ones to the data (bias/intercept term)
|
||||||
{
|
Mat data_t;
|
||||||
CV_Error( CV_StsBadArg, "number of rows in data and labels should be the equal" );
|
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
|
||||||
}
|
|
||||||
|
|
||||||
|
// coefficient matrix (zero-initialized)
|
||||||
Mat thetas = Mat::zeros(num_classes, data_t.cols, CV_32F);
|
Mat thetas;
|
||||||
Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F);
|
Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F);
|
||||||
|
|
||||||
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
|
// fit the model (handles binary and multiclass cases)
|
||||||
Mat new_local_labels;
|
|
||||||
|
|
||||||
int ii=0;
|
|
||||||
Mat new_theta;
|
Mat new_theta;
|
||||||
|
Mat labels;
|
||||||
if(num_classes == 2)
|
if(num_classes == 2)
|
||||||
{
|
{
|
||||||
labels_l.convertTo(labels, CV_32F);
|
labels_l.convertTo(labels, CV_32F);
|
||||||
@ -193,12 +184,14 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
|||||||
{
|
{
|
||||||
/* take each class and rename classes you will get a theta per class
|
/* take each class and rename classes you will get a theta per class
|
||||||
as in multi class class scenario, we will have n thetas for n classes */
|
as in multi class class scenario, we will have n thetas for n classes */
|
||||||
ii = 0;
|
thetas.create(num_classes, data_t.cols, CV_32F);
|
||||||
|
Mat labels_binary;
|
||||||
|
int ii = 0;
|
||||||
for(map<int,int>::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it)
|
for(map<int,int>::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it)
|
||||||
{
|
{
|
||||||
new_local_labels = (labels_l == it->second)/255;
|
// one-vs-rest (OvR) scheme
|
||||||
new_local_labels.convertTo(labels, CV_32F);
|
labels_binary = (labels_l == it->second)/255;
|
||||||
|
labels_binary.convertTo(labels, CV_32F);
|
||||||
if(this->params.train_method == LogisticRegression::BATCH)
|
if(this->params.train_method == LogisticRegression::BATCH)
|
||||||
new_theta = batch_gradient_descent(data_t, labels, init_theta);
|
new_theta = batch_gradient_descent(data_t, labels, init_theta);
|
||||||
else
|
else
|
||||||
@ -208,38 +201,28 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check that the estimates are stable and finite
|
||||||
this->learnt_thetas = thetas.clone();
|
this->learnt_thetas = thetas.clone();
|
||||||
if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) )
|
if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) )
|
||||||
{
|
{
|
||||||
CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" );
|
CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// success
|
||||||
ok = true;
|
ok = true;
|
||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const
|
float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const
|
||||||
{
|
{
|
||||||
/* returns a class of the predicted class
|
|
||||||
class names can be 1,2,3,4, .... etc */
|
|
||||||
Mat thetas, data, pred_labs;
|
|
||||||
data = samples.getMat();
|
|
||||||
|
|
||||||
const bool rawout = flags & StatModel::RAW_OUTPUT;
|
|
||||||
|
|
||||||
// check if learnt_mats array is populated
|
// check if learnt_mats array is populated
|
||||||
if(this->learnt_thetas.total()<=0)
|
if(!this->isTrained())
|
||||||
{
|
{
|
||||||
CV_Error( CV_StsBadArg, "classifier should be trained first" );
|
CV_Error( CV_StsBadArg, "classifier should be trained first" );
|
||||||
}
|
}
|
||||||
if(data.type() != CV_32F)
|
|
||||||
{
|
|
||||||
CV_Error( CV_StsBadArg, "data must be of floating type" );
|
|
||||||
}
|
|
||||||
|
|
||||||
// add a column of ones
|
|
||||||
Mat data_t;
|
|
||||||
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
|
|
||||||
|
|
||||||
|
// coefficient matrix
|
||||||
|
Mat thetas;
|
||||||
if ( learnt_thetas.type() == CV_32F )
|
if ( learnt_thetas.type() == CV_32F )
|
||||||
{
|
{
|
||||||
thetas = learnt_thetas;
|
thetas = learnt_thetas;
|
||||||
@ -248,52 +231,65 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
|
|||||||
{
|
{
|
||||||
this->learnt_thetas.convertTo( thetas, CV_32F );
|
this->learnt_thetas.convertTo( thetas, CV_32F );
|
||||||
}
|
}
|
||||||
|
|
||||||
CV_Assert(thetas.rows > 0);
|
CV_Assert(thetas.rows > 0);
|
||||||
|
|
||||||
double min_val;
|
// data samples
|
||||||
double max_val;
|
Mat data = samples.getMat();
|
||||||
|
if(data.type() != CV_32F)
|
||||||
|
{
|
||||||
|
CV_Error( CV_StsBadArg, "data must be of floating type" );
|
||||||
|
}
|
||||||
|
|
||||||
Point min_loc;
|
// add a column of ones to the data (bias/intercept term)
|
||||||
Point max_loc;
|
Mat data_t;
|
||||||
|
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
|
||||||
|
CV_Assert(data_t.cols == thetas.cols);
|
||||||
|
|
||||||
Mat labels;
|
// predict class labels for samples (handles binary and multiclass cases)
|
||||||
Mat labels_c;
|
Mat labels_c;
|
||||||
|
Mat pred_m;
|
||||||
Mat temp_pred;
|
Mat temp_pred;
|
||||||
Mat pred_m = Mat::zeros(data_t.rows, thetas.rows, data.type());
|
|
||||||
|
|
||||||
if(thetas.rows == 1)
|
if(thetas.rows == 1)
|
||||||
{
|
{
|
||||||
temp_pred = calc_sigmoid(data_t*thetas.t());
|
// apply sigmoid function
|
||||||
|
temp_pred = calc_sigmoid(data_t * thetas.t());
|
||||||
CV_Assert(temp_pred.cols==1);
|
CV_Assert(temp_pred.cols==1);
|
||||||
|
pred_m = temp_pred.clone();
|
||||||
|
|
||||||
// if greater than 0.5, predict class 0 or predict class 1
|
// if greater than 0.5, predict class 0 or predict class 1
|
||||||
temp_pred = (temp_pred>0.5)/255;
|
temp_pred = (temp_pred > 0.5f) / 255;
|
||||||
temp_pred.convertTo(labels_c, CV_32S);
|
temp_pred.convertTo(labels_c, CV_32S);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
for(int i = 0;i<thetas.rows;i++)
|
// apply sigmoid function
|
||||||
|
pred_m.create(data_t.rows, thetas.rows, data.type());
|
||||||
|
for(int i = 0; i < thetas.rows; i++)
|
||||||
{
|
{
|
||||||
temp_pred = calc_sigmoid(data_t * thetas.row(i).t());
|
temp_pred = calc_sigmoid(data_t * thetas.row(i).t());
|
||||||
vconcat(temp_pred, pred_m.col(i));
|
vconcat(temp_pred, pred_m.col(i));
|
||||||
}
|
}
|
||||||
for(int i = 0;i<pred_m.rows;i++)
|
|
||||||
|
// predict class with the maximum output
|
||||||
|
Point max_loc;
|
||||||
|
Mat labels;
|
||||||
|
for(int i = 0; i < pred_m.rows; i++)
|
||||||
{
|
{
|
||||||
temp_pred = pred_m.row(i);
|
temp_pred = pred_m.row(i);
|
||||||
minMaxLoc( temp_pred, &min_val, &max_val, &min_loc, &max_loc, Mat() );
|
minMaxLoc( temp_pred, NULL, NULL, NULL, &max_loc );
|
||||||
labels.push_back(max_loc.x);
|
labels.push_back(max_loc.x);
|
||||||
}
|
}
|
||||||
labels.convertTo(labels_c, CV_32S);
|
labels.convertTo(labels_c, CV_32S);
|
||||||
}
|
}
|
||||||
pred_labs = remap_labels(labels_c, this->reverse_mapper);
|
|
||||||
// convert pred_labs to integer type
|
// return label of the predicted class. class names can be 1,2,3,...
|
||||||
|
Mat pred_labs = remap_labels(labels_c, this->reverse_mapper);
|
||||||
pred_labs.convertTo(pred_labs, CV_32S);
|
pred_labs.convertTo(pred_labs, CV_32S);
|
||||||
|
|
||||||
// return either the labels or the raw output
|
// return either the labels or the raw output
|
||||||
if ( results.needed() )
|
if ( results.needed() )
|
||||||
{
|
{
|
||||||
if ( rawout )
|
if ( flags & StatModel::RAW_OUTPUT )
|
||||||
{
|
{
|
||||||
pred_m.copyTo( results );
|
pred_m.copyTo( results );
|
||||||
}
|
}
|
||||||
@ -303,7 +299,7 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ( pred_labs.empty() ? 0.f : (float) pred_labs.at< int >( 0 ) );
|
return ( pred_labs.empty() ? 0.f : static_cast<float>(pred_labs.at<int>(0)) );
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const
|
Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const
|
||||||
@ -595,11 +591,6 @@ void LogisticRegressionImpl::read(const FileNode& fn)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat LogisticRegressionImpl::get_learnt_thetas() const
|
|
||||||
{
|
|
||||||
return this->learnt_thetas;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user