Merge pull request #5964 from amroamroamro:fix_lr
This commit is contained in:
commit
a1d7e38adb
@ -96,11 +96,11 @@ public:
|
||||
CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit)
|
||||
|
||||
virtual bool train( const Ptr<TrainData>& trainData, int=0 );
|
||||
virtual float predict(InputArray samples, OutputArray results, int) const;
|
||||
virtual float predict(InputArray samples, OutputArray results, int flags=0) const;
|
||||
virtual void clear();
|
||||
virtual void write(FileStorage& fs) const;
|
||||
virtual void read(const FileNode& fn);
|
||||
virtual Mat get_learnt_thetas() const;
|
||||
virtual Mat get_learnt_thetas() const { return learnt_thetas; }
|
||||
virtual int getVarCount() const { return learnt_thetas.cols; }
|
||||
virtual bool isTrained() const { return !learnt_thetas.empty(); }
|
||||
virtual bool isClassifier() const { return true; }
|
||||
@ -129,57 +129,48 @@ Ptr<LogisticRegression> LogisticRegression::create()
|
||||
|
||||
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||
{
|
||||
// return value
|
||||
bool ok = false;
|
||||
|
||||
clear();
|
||||
Mat _data_i = trainData->getSamples();
|
||||
Mat _labels_i = trainData->getResponses();
|
||||
|
||||
// check size and type of training data
|
||||
CV_Assert( !_labels_i.empty() && !_data_i.empty());
|
||||
|
||||
// check the number of columns
|
||||
if(_labels_i.cols != 1)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "_labels_i should be a column matrix" );
|
||||
CV_Error( CV_StsBadArg, "labels should be a column matrix" );
|
||||
}
|
||||
|
||||
// check data type.
|
||||
// data should be of floating type CV_32FC1
|
||||
|
||||
if((_data_i.type() != CV_32FC1) || (_labels_i.type() != CV_32FC1))
|
||||
if(_data_i.type() != CV_32FC1 || _labels_i.type() != CV_32FC1)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" );
|
||||
}
|
||||
if(_labels_i.rows != _data_i.rows)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "number of rows in data and labels should be equal" );
|
||||
}
|
||||
|
||||
bool ok = false;
|
||||
|
||||
Mat labels;
|
||||
|
||||
// class labels
|
||||
set_label_map(_labels_i);
|
||||
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
|
||||
int num_classes = (int) this->forward_mapper.size();
|
||||
|
||||
// add a column of ones
|
||||
Mat data_t;
|
||||
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
|
||||
|
||||
if(num_classes < 2)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "data should have atleast 2 classes" );
|
||||
}
|
||||
|
||||
if(_labels_i.rows != _data_i.rows)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "number of rows in data and labels should be the equal" );
|
||||
}
|
||||
// add a column of ones to the data (bias/intercept term)
|
||||
Mat data_t;
|
||||
hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t );
|
||||
|
||||
|
||||
Mat thetas = Mat::zeros(num_classes, data_t.cols, CV_32F);
|
||||
// coefficient matrix (zero-initialized)
|
||||
Mat thetas;
|
||||
Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F);
|
||||
|
||||
Mat labels_l = remap_labels(_labels_i, this->forward_mapper);
|
||||
Mat new_local_labels;
|
||||
|
||||
int ii=0;
|
||||
// fit the model (handles binary and multiclass cases)
|
||||
Mat new_theta;
|
||||
|
||||
Mat labels;
|
||||
if(num_classes == 2)
|
||||
{
|
||||
labels_l.convertTo(labels, CV_32F);
|
||||
@ -193,12 +184,14 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||
{
|
||||
/* take each class and rename classes you will get a theta per class
|
||||
as in multi class class scenario, we will have n thetas for n classes */
|
||||
ii = 0;
|
||||
|
||||
thetas.create(num_classes, data_t.cols, CV_32F);
|
||||
Mat labels_binary;
|
||||
int ii = 0;
|
||||
for(map<int,int>::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it)
|
||||
{
|
||||
new_local_labels = (labels_l == it->second)/255;
|
||||
new_local_labels.convertTo(labels, CV_32F);
|
||||
// one-vs-rest (OvR) scheme
|
||||
labels_binary = (labels_l == it->second)/255;
|
||||
labels_binary.convertTo(labels, CV_32F);
|
||||
if(this->params.train_method == LogisticRegression::BATCH)
|
||||
new_theta = batch_gradient_descent(data_t, labels, init_theta);
|
||||
else
|
||||
@ -208,38 +201,28 @@ bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||
}
|
||||
}
|
||||
|
||||
// check that the estimates are stable and finite
|
||||
this->learnt_thetas = thetas.clone();
|
||||
if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) )
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" );
|
||||
}
|
||||
|
||||
// success
|
||||
ok = true;
|
||||
return ok;
|
||||
}
|
||||
|
||||
float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const
|
||||
{
|
||||
/* returns a class of the predicted class
|
||||
class names can be 1,2,3,4, .... etc */
|
||||
Mat thetas, data, pred_labs;
|
||||
data = samples.getMat();
|
||||
|
||||
const bool rawout = flags & StatModel::RAW_OUTPUT;
|
||||
|
||||
// check if learnt_mats array is populated
|
||||
if(this->learnt_thetas.total()<=0)
|
||||
if(!this->isTrained())
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "classifier should be trained first" );
|
||||
}
|
||||
if(data.type() != CV_32F)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "data must be of floating type" );
|
||||
}
|
||||
|
||||
// add a column of ones
|
||||
Mat data_t;
|
||||
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
|
||||
|
||||
// coefficient matrix
|
||||
Mat thetas;
|
||||
if ( learnt_thetas.type() == CV_32F )
|
||||
{
|
||||
thetas = learnt_thetas;
|
||||
@ -248,52 +231,65 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
|
||||
{
|
||||
this->learnt_thetas.convertTo( thetas, CV_32F );
|
||||
}
|
||||
|
||||
CV_Assert(thetas.rows > 0);
|
||||
|
||||
double min_val;
|
||||
double max_val;
|
||||
// data samples
|
||||
Mat data = samples.getMat();
|
||||
if(data.type() != CV_32F)
|
||||
{
|
||||
CV_Error( CV_StsBadArg, "data must be of floating type" );
|
||||
}
|
||||
|
||||
Point min_loc;
|
||||
Point max_loc;
|
||||
// add a column of ones to the data (bias/intercept term)
|
||||
Mat data_t;
|
||||
hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t );
|
||||
CV_Assert(data_t.cols == thetas.cols);
|
||||
|
||||
Mat labels;
|
||||
// predict class labels for samples (handles binary and multiclass cases)
|
||||
Mat labels_c;
|
||||
Mat pred_m;
|
||||
Mat temp_pred;
|
||||
Mat pred_m = Mat::zeros(data_t.rows, thetas.rows, data.type());
|
||||
|
||||
if(thetas.rows == 1)
|
||||
{
|
||||
temp_pred = calc_sigmoid(data_t*thetas.t());
|
||||
// apply sigmoid function
|
||||
temp_pred = calc_sigmoid(data_t * thetas.t());
|
||||
CV_Assert(temp_pred.cols==1);
|
||||
pred_m = temp_pred.clone();
|
||||
|
||||
// if greater than 0.5, predict class 0 or predict class 1
|
||||
temp_pred = (temp_pred>0.5)/255;
|
||||
temp_pred = (temp_pred > 0.5f) / 255;
|
||||
temp_pred.convertTo(labels_c, CV_32S);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0;i<thetas.rows;i++)
|
||||
// apply sigmoid function
|
||||
pred_m.create(data_t.rows, thetas.rows, data.type());
|
||||
for(int i = 0; i < thetas.rows; i++)
|
||||
{
|
||||
temp_pred = calc_sigmoid(data_t * thetas.row(i).t());
|
||||
vconcat(temp_pred, pred_m.col(i));
|
||||
}
|
||||
for(int i = 0;i<pred_m.rows;i++)
|
||||
|
||||
// predict class with the maximum output
|
||||
Point max_loc;
|
||||
Mat labels;
|
||||
for(int i = 0; i < pred_m.rows; i++)
|
||||
{
|
||||
temp_pred = pred_m.row(i);
|
||||
minMaxLoc( temp_pred, &min_val, &max_val, &min_loc, &max_loc, Mat() );
|
||||
minMaxLoc( temp_pred, NULL, NULL, NULL, &max_loc );
|
||||
labels.push_back(max_loc.x);
|
||||
}
|
||||
labels.convertTo(labels_c, CV_32S);
|
||||
}
|
||||
pred_labs = remap_labels(labels_c, this->reverse_mapper);
|
||||
// convert pred_labs to integer type
|
||||
|
||||
// return label of the predicted class. class names can be 1,2,3,...
|
||||
Mat pred_labs = remap_labels(labels_c, this->reverse_mapper);
|
||||
pred_labs.convertTo(pred_labs, CV_32S);
|
||||
|
||||
// return either the labels or the raw output
|
||||
if ( results.needed() )
|
||||
{
|
||||
if ( rawout )
|
||||
if ( flags & StatModel::RAW_OUTPUT )
|
||||
{
|
||||
pred_m.copyTo( results );
|
||||
}
|
||||
@ -303,7 +299,7 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i
|
||||
}
|
||||
}
|
||||
|
||||
return ( pred_labs.empty() ? 0.f : (float) pred_labs.at< int >( 0 ) );
|
||||
return ( pred_labs.empty() ? 0.f : static_cast<float>(pred_labs.at<int>(0)) );
|
||||
}
|
||||
|
||||
Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const
|
||||
@ -595,11 +591,6 @@ void LogisticRegressionImpl::read(const FileNode& fn)
|
||||
}
|
||||
}
|
||||
|
||||
Mat LogisticRegressionImpl::get_learnt_thetas() const
|
||||
{
|
||||
return this->learnt_thetas;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user