Merge pull request #5964 from amroamroamro:fix_lr

This commit is contained in:
Alexander Alekhin 2016-01-14 12:08:53 +00:00
commit a1d7e38adb

View File

@ -96,11 +96,11 @@ public:
CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit) CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit)
virtual bool train( const Ptr<TrainData>& trainData, int=0 ); virtual bool train( const Ptr<TrainData>& trainData, int=0 );
virtual float predict(InputArray samples, OutputArray results, int) const; virtual float predict(InputArray samples, OutputArray results, int flags=0) const;
virtual void clear(); virtual void clear();
virtual void write(FileStorage& fs) const; virtual void write(FileStorage& fs) const;
virtual void read(const FileNode& fn); virtual void read(const FileNode& fn);
virtual Mat get_learnt_thetas() const; virtual Mat get_learnt_thetas() const { return learnt_thetas; }
virtual int getVarCount() const { return learnt_thetas.cols; } virtual int getVarCount() const { return learnt_thetas.cols; }
virtual bool isTrained() const { return !learnt_thetas.empty(); } virtual bool isTrained() const { return !learnt_thetas.empty(); }
virtual bool isClassifier() const { return true; } virtual bool isClassifier() const { return true; }
@ -129,57 +129,48 @@ Ptr<LogisticRegression> LogisticRegression::create()
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int) bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{ {
// return value
bool ok = false;
clear(); clear();
Mat _data_i = trainData->getSamples(); Mat _data_i = trainData->getSamples();
Mat _labels_i = trainData->getResponses(); Mat _labels_i = trainData->getResponses();
// check size and type of training data
CV_Assert( !_labels_i.empty() && !_data_i.empty()); CV_Assert( !_labels_i.empty() && !_data_i.empty());
// check the number of columns
if(_labels_i.cols != 1) if(_labels_i.cols != 1)
{ {
CV_Error( CV_StsBadArg, "_labels_i should be a column matrix" ); CV_Error( CV_StsBadArg, "labels should be a column matrix" );
} }
if(_data_i.type() != CV_32FC1 || _labels_i.type() != CV_32FC1)
// check data type.
// data should be of floating type CV_32FC1
if((_data_i.type() != CV_32FC1) || (_labels_i.type() != CV_32FC1))
{ {
CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" ); CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" );
} }
if(_labels_i.rows != _data_i.rows)
{
CV_Error( CV_StsBadArg, "number of rows in data and labels should be equal" );
}
bool ok = false; // class labels
Mat labels;
set_label_map(_labels_i); set_label_map(_labels_i);
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
int num_classes = (int) this->forward_mapper.size(); int num_classes = (int) this->forward_mapper.size();
// add a column of ones
Mat data_t;
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
if(num_classes < 2) if(num_classes < 2)
{ {
CV_Error( CV_StsBadArg, "data should have atleast 2 classes" ); CV_Error( CV_StsBadArg, "data should have atleast 2 classes" );
} }
if(_labels_i.rows != _data_i.rows) // add a column of ones to the data (bias/intercept term)
{ Mat data_t;
CV_Error( CV_StsBadArg, "number of rows in data and labels should be the equal" ); hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
}
// coefficient matrix (zero-initialized)
Mat thetas = Mat::zeros(num_classes, data_t.cols, CV_32F); Mat thetas;
Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F); Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F);
Mat labels_l = remap_labels(_labels_i, this->forward_mapper); // fit the model (handles binary and multiclass cases)
Mat new_local_labels;
int ii=0;
Mat new_theta; Mat new_theta;
Mat labels;
if(num_classes == 2) if(num_classes == 2)
{ {
labels_l.convertTo(labels, CV_32F); labels_l.convertTo(labels, CV_32F);
@ -193,12 +184,14 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{ {
/* take each class and rename classes you will get a theta per class /* take each class and rename classes you will get a theta per class
as in multi class class scenario, we will have n thetas for n classes */ as in multi class class scenario, we will have n thetas for n classes */
ii = 0; thetas.create(num_classes, data_t.cols, CV_32F);
Mat labels_binary;
int ii = 0;
for(map<int,int>::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it) for(map<int,int>::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it)
{ {
new_local_labels = (labels_l == it->second)/255; // one-vs-rest (OvR) scheme
new_local_labels.convertTo(labels, CV_32F); labels_binary = (labels_l == it->second)/255;
labels_binary.convertTo(labels, CV_32F);
if(this->params.train_method == LogisticRegression::BATCH) if(this->params.train_method == LogisticRegression::BATCH)
new_theta = batch_gradient_descent(data_t, labels, init_theta); new_theta = batch_gradient_descent(data_t, labels, init_theta);
else else
@ -208,38 +201,28 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
} }
} }
// check that the estimates are stable and finite
this->learnt_thetas = thetas.clone(); this->learnt_thetas = thetas.clone();
if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) ) if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) )
{ {
CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" ); CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" );
} }
// success
ok = true; ok = true;
return ok; return ok;
} }
float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const
{ {
/* returns a class of the predicted class
class names can be 1,2,3,4, .... etc */
Mat thetas, data, pred_labs;
data = samples.getMat();
const bool rawout = flags & StatModel::RAW_OUTPUT;
// check if learnt_mats array is populated // check if learnt_mats array is populated
if(this->learnt_thetas.total()<=0) if(!this->isTrained())
{ {
CV_Error( CV_StsBadArg, "classifier should be trained first" ); CV_Error( CV_StsBadArg, "classifier should be trained first" );
} }
if(data.type() != CV_32F)
{
CV_Error( CV_StsBadArg, "data must be of floating type" );
}
// add a column of ones
Mat data_t;
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
// coefficient matrix
Mat thetas;
if ( learnt_thetas.type() == CV_32F ) if ( learnt_thetas.type() == CV_32F )
{ {
thetas = learnt_thetas; thetas = learnt_thetas;
@ -248,52 +231,65 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
{ {
this->learnt_thetas.convertTo( thetas, CV_32F ); this->learnt_thetas.convertTo( thetas, CV_32F );
} }
CV_Assert(thetas.rows > 0); CV_Assert(thetas.rows > 0);
double min_val; // data samples
double max_val; Mat data = samples.getMat();
if(data.type() != CV_32F)
{
CV_Error( CV_StsBadArg, "data must be of floating type" );
}
Point min_loc; // add a column of ones to the data (bias/intercept term)
Point max_loc; Mat data_t;
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
CV_Assert(data_t.cols == thetas.cols);
Mat labels; // predict class labels for samples (handles binary and multiclass cases)
Mat labels_c; Mat labels_c;
Mat pred_m;
Mat temp_pred; Mat temp_pred;
Mat pred_m = Mat::zeros(data_t.rows, thetas.rows, data.type());
if(thetas.rows == 1) if(thetas.rows == 1)
{ {
// apply sigmoid function
temp_pred = calc_sigmoid(data_t * thetas.t()); temp_pred = calc_sigmoid(data_t * thetas.t());
CV_Assert(temp_pred.cols==1); CV_Assert(temp_pred.cols==1);
pred_m = temp_pred.clone();
// if greater than 0.5, predict class 0 or predict class 1 // if greater than 0.5, predict class 0 or predict class 1
temp_pred = (temp_pred>0.5)/255; temp_pred = (temp_pred > 0.5f) / 255;
temp_pred.convertTo(labels_c, CV_32S); temp_pred.convertTo(labels_c, CV_32S);
} }
else else
{ {
// apply sigmoid function
pred_m.create(data_t.rows, thetas.rows, data.type());
for(int i = 0; i < thetas.rows; i++) for(int i = 0; i < thetas.rows; i++)
{ {
temp_pred = calc_sigmoid(data_t * thetas.row(i).t()); temp_pred = calc_sigmoid(data_t * thetas.row(i).t());
vconcat(temp_pred, pred_m.col(i)); vconcat(temp_pred, pred_m.col(i));
} }
// predict class with the maximum output
Point max_loc;
Mat labels;
for(int i = 0; i < pred_m.rows; i++) for(int i = 0; i < pred_m.rows; i++)
{ {
temp_pred = pred_m.row(i); temp_pred = pred_m.row(i);
minMaxLoc( temp_pred, &min_val, &max_val, &min_loc, &max_loc, Mat() ); minMaxLoc( temp_pred, NULL, NULL, NULL, &max_loc );
labels.push_back(max_loc.x); labels.push_back(max_loc.x);
} }
labels.convertTo(labels_c, CV_32S); labels.convertTo(labels_c, CV_32S);
} }
pred_labs = remap_labels(labels_c, this->reverse_mapper);
// convert pred_labs to integer type // return label of the predicted class. class names can be 1,2,3,...
Mat pred_labs = remap_labels(labels_c, this->reverse_mapper);
pred_labs.convertTo(pred_labs, CV_32S); pred_labs.convertTo(pred_labs, CV_32S);
// return either the labels or the raw output // return either the labels or the raw output
if ( results.needed() ) if ( results.needed() )
{ {
if ( rawout ) if ( flags & StatModel::RAW_OUTPUT )
{ {
pred_m.copyTo( results ); pred_m.copyTo( results );
} }
@ -303,7 +299,7 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
} }
} }
return ( pred_labs.empty() ? 0.f : (float) pred_labs.at< int >( 0 ) ); return ( pred_labs.empty() ? 0.f : static_cast<float>(pred_labs.at<int>(0)) );
} }
Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const
@ -595,11 +591,6 @@ void LogisticRegressionImpl::read(const FileNode& fn)
} }
} }
Mat LogisticRegressionImpl::get_learnt_thetas() const
{
return this->learnt_thetas;
}
} }
} }