Merge pull request #5964 from amroamroamro:fix_lr

This commit is contained in:
Alexander Alekhin 2016-01-14 12:08:53 +00:00
commit a1d7e38adb

View File

@ -96,11 +96,11 @@ public:
CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit)
virtual bool train( const Ptr<TrainData>& trainData, int=0 );
virtual float predict(InputArray samples, OutputArray results, int) const;
virtual float predict(InputArray samples, OutputArray results, int flags=0) const;
virtual void clear();
virtual void write(FileStorage& fs) const;
virtual void read(const FileNode& fn);
virtual Mat get_learnt_thetas() const;
virtual Mat get_learnt_thetas() const { return learnt_thetas; }
virtual int getVarCount() const { return learnt_thetas.cols; }
virtual bool isTrained() const { return !learnt_thetas.empty(); }
virtual bool isClassifier() const { return true; }
@ -129,57 +129,48 @@ Ptr<LogisticRegression> LogisticRegression::create()
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{
// return value
bool ok = false;
clear();
Mat _data_i = trainData->getSamples();
Mat _labels_i = trainData->getResponses();
// check size and type of training data
CV_Assert( !_labels_i.empty() && !_data_i.empty());
// check the number of columns
if(_labels_i.cols != 1)
{
CV_Error( CV_StsBadArg, "_labels_i should be a column matrix" );
CV_Error( CV_StsBadArg, "labels should be a column matrix" );
}
// check data type.
// data should be of floating type CV_32FC1
if((_data_i.type() != CV_32FC1) || (_labels_i.type() != CV_32FC1))
if(_data_i.type() != CV_32FC1 || _labels_i.type() != CV_32FC1)
{
CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" );
}
if(_labels_i.rows != _data_i.rows)
{
CV_Error( CV_StsBadArg, "number of rows in data and labels should be equal" );
}
bool ok = false;
Mat labels;
// class labels
set_label_map(_labels_i);
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
int num_classes = (int) this->forward_mapper.size();
// add a column of ones
Mat data_t;
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
if(num_classes < 2)
{
CV_Error( CV_StsBadArg, "data should have atleast 2 classes" );
}
if(_labels_i.rows != _data_i.rows)
{
CV_Error( CV_StsBadArg, "number of rows in data and labels should be the equal" );
}
// add a column of ones to the data (bias/intercept term)
Mat data_t;
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
Mat thetas = Mat::zeros(num_classes, data_t.cols, CV_32F);
// coefficient matrix (zero-initialized)
Mat thetas;
Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F);
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
Mat new_local_labels;
int ii=0;
// fit the model (handles binary and multiclass cases)
Mat new_theta;
Mat labels;
if(num_classes == 2)
{
labels_l.convertTo(labels, CV_32F);
@ -193,12 +184,14 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{
/* take each class and rename classes you will get a theta per class
as in multi class class scenario, we will have n thetas for n classes */
ii = 0;
thetas.create(num_classes, data_t.cols, CV_32F);
Mat labels_binary;
int ii = 0;
for(map<int,int>::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it)
{
new_local_labels = (labels_l == it->second)/255;
new_local_labels.convertTo(labels, CV_32F);
// one-vs-rest (OvR) scheme
labels_binary = (labels_l == it->second)/255;
labels_binary.convertTo(labels, CV_32F);
if(this->params.train_method == LogisticRegression::BATCH)
new_theta = batch_gradient_descent(data_t, labels, init_theta);
else
@ -208,38 +201,28 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
}
}
// check that the estimates are stable and finite
this->learnt_thetas = thetas.clone();
if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) )
{
CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" );
}
// success
ok = true;
return ok;
}
float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const
{
/* returns a class of the predicted class
class names can be 1,2,3,4, .... etc */
Mat thetas, data, pred_labs;
data = samples.getMat();
const bool rawout = flags & StatModel::RAW_OUTPUT;
// check if learnt_mats array is populated
if(this->learnt_thetas.total()<=0)
if(!this->isTrained())
{
CV_Error( CV_StsBadArg, "classifier should be trained first" );
}
if(data.type() != CV_32F)
{
CV_Error( CV_StsBadArg, "data must be of floating type" );
}
// add a column of ones
Mat data_t;
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
// coefficient matrix
Mat thetas;
if ( learnt_thetas.type() == CV_32F )
{
thetas = learnt_thetas;
@ -248,52 +231,65 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
{
this->learnt_thetas.convertTo( thetas, CV_32F );
}
CV_Assert(thetas.rows > 0);
double min_val;
double max_val;
// data samples
Mat data = samples.getMat();
if(data.type() != CV_32F)
{
CV_Error( CV_StsBadArg, "data must be of floating type" );
}
Point min_loc;
Point max_loc;
// add a column of ones to the data (bias/intercept term)
Mat data_t;
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
CV_Assert(data_t.cols == thetas.cols);
Mat labels;
// predict class labels for samples (handles binary and multiclass cases)
Mat labels_c;
Mat pred_m;
Mat temp_pred;
Mat pred_m = Mat::zeros(data_t.rows, thetas.rows, data.type());
if(thetas.rows == 1)
{
temp_pred = calc_sigmoid(data_t*thetas.t());
// apply sigmoid function
temp_pred = calc_sigmoid(data_t * thetas.t());
CV_Assert(temp_pred.cols==1);
pred_m = temp_pred.clone();
// if greater than 0.5, predict class 0 or predict class 1
temp_pred = (temp_pred>0.5)/255;
temp_pred = (temp_pred > 0.5f) / 255;
temp_pred.convertTo(labels_c, CV_32S);
}
else
{
for(int i = 0;i<thetas.rows;i++)
// apply sigmoid function
pred_m.create(data_t.rows, thetas.rows, data.type());
for(int i = 0; i < thetas.rows; i++)
{
temp_pred = calc_sigmoid(data_t * thetas.row(i).t());
vconcat(temp_pred, pred_m.col(i));
}
for(int i = 0;i<pred_m.rows;i++)
// predict class with the maximum output
Point max_loc;
Mat labels;
for(int i = 0; i < pred_m.rows; i++)
{
temp_pred = pred_m.row(i);
minMaxLoc( temp_pred, &min_val, &max_val, &min_loc, &max_loc, Mat() );
minMaxLoc( temp_pred, NULL, NULL, NULL, &max_loc );
labels.push_back(max_loc.x);
}
labels.convertTo(labels_c, CV_32S);
}
pred_labs = remap_labels(labels_c, this->reverse_mapper);
// convert pred_labs to integer type
// return label of the predicted class. class names can be 1,2,3,...
Mat pred_labs = remap_labels(labels_c, this->reverse_mapper);
pred_labs.convertTo(pred_labs, CV_32S);
// return either the labels or the raw output
if ( results.needed() )
{
if ( rawout )
if ( flags & StatModel::RAW_OUTPUT )
{
pred_m.copyTo( results );
}
@ -303,7 +299,7 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
}
}
return ( pred_labs.empty() ? 0.f : (float) pred_labs.at< int >( 0 ) );
return ( pred_labs.empty() ? 0.f : static_cast<float>(pred_labs.at<int>(0)) );
}
Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const
@ -595,11 +591,6 @@ void LogisticRegressionImpl::read(const FileNode& fn)
}
}
Mat LogisticRegressionImpl::get_learnt_thetas() const
{
return this->learnt_thetas;
}
}
}