FaceRecognizer supports updating a model now. Documentation has been updated to reflect latest changes.

This commit is contained in:
Philipp Wagner
2012-07-30 19:31:10 +02:00
parent b9d7c712f5
commit 62a8f6783e
3 changed files with 99 additions and 25 deletions

View File

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2011. Philipp Wagner <bytefish[at]gmx[dot]de>.
* Copyright (c) 2011,2012. Philipp Wagner <bytefish[at]gmx[dot]de>.
* Released to public domain under terms of the BSD Simplified license.
*
* Redistribution and use in source and binary forms, with or without
@@ -197,10 +197,10 @@ public:
void predict(InputArray _src, int &label, double &dist) const;
// See FaceRecognizer::load.
virtual void load(const FileStorage& fs);
void load(const FileStorage& fs);
// See FaceRecognizer::save.
virtual void save(FileStorage& fs) const;
void save(FileStorage& fs) const;
AlgorithmInfo* info() const;
};
@@ -223,6 +223,12 @@ private:
vector<Mat> _histograms;
Mat _labels;
// Computes a LBPH model with images in src and
// corresponding labels in labels, possibly preserving
// old model data.
void train(InputArrayOfArrays src, InputArray labels, bool preserveData);
public:
using FaceRecognizer::save;
using FaceRecognizer::load;
@@ -265,6 +271,10 @@ public:
// corresponding labels in labels.
void train(InputArrayOfArrays src, InputArray labels);
// Updates this LBPH model with images in src and
// corresponding labels in labels.
void update(InputArrayOfArrays src, InputArray labels);
// Predicts the label of a query image in src.
int predict(InputArray src) const;
@@ -290,6 +300,11 @@ public:
//------------------------------------------------------------------------------
// FaceRecognizer
//------------------------------------------------------------------------------
void FaceRecognizer::update(InputArrayOfArrays, InputArray) {
string error_msg = format("This FaceRecognizer (%s) does not support updating, you have to use FaceRecognizer::train to update it.", this->name().c_str());
CV_Error(CV_StsNotImplemented, error_msg);
}
void FaceRecognizer::save(const string& filename) const {
FileStorage fs(filename, FileStorage::WRITE);
if (!fs.isOpened())
@@ -727,28 +742,45 @@ void LBPH::save(FileStorage& fs) const {
fs << "labels" << _labels;
}
void LBPH::train(InputArrayOfArrays _src, InputArray _lbls) {
if(_src.kind() != _InputArray::STD_VECTOR_MAT && _src.kind() != _InputArray::STD_VECTOR_VECTOR) {
void LBPH::train(InputArrayOfArrays _in_src, InputArray _in_labels) {
this->train(_in_src, _in_labels, false);
}
void LBPH::update(InputArrayOfArrays _in_src, InputArray _in_labels) {
// got no data, just return
if(_in_src.total() == 0)
return;
this->train(_in_src, _in_labels, true);
}
void LBPH::train(InputArrayOfArrays _in_src, InputArray _in_labels, bool preserveData) {
if(_in_src.kind() != _InputArray::STD_VECTOR_MAT && _in_src.kind() != _InputArray::STD_VECTOR_VECTOR) {
string error_message = "The images are expected as InputArray::STD_VECTOR_MAT (a std::vector<Mat>) or _InputArray::STD_VECTOR_VECTOR (a std::vector< vector<...> >).";
CV_Error(CV_StsBadArg, error_message);
}
if(_src.total() == 0) {
if(_in_src.total() == 0) {
string error_message = format("Empty training data was given. You'll need more than one sample to learn a model.");
CV_Error(CV_StsUnsupportedFormat, error_message);
} else if(_lbls.getMat().type() != CV_32SC1) {
string error_message = format("Labels must be given as integer (CV_32SC1). Expected %d, but was %d.", CV_32SC1, _lbls.type());
} else if(_in_labels.getMat().type() != CV_32SC1) {
string error_message = format("Labels must be given as integer (CV_32SC1). Expected %d, but was %d.", CV_32SC1, _in_labels.type());
CV_Error(CV_StsUnsupportedFormat, error_message);
}
// get the vector of matrices
vector<Mat> src;
_src.getMatVector(src);
_in_src.getMatVector(src);
// get the label matrix
Mat labels = _lbls.getMat();
Mat labels = _in_labels.getMat();
// check if data is well- aligned
if(labels.total() != src.size()) {
string error_message = format("The number of samples (src) must equal the number of labels (labels). Was len(samples)=%d, len(labels)=%d.", src.size(), _labels.total());
CV_Error(CV_StsBadArg, error_message);
}
// if this model should be trained without preserving old data, delete old model data
if(!preserveData) {
_labels.release();
_histograms.clear();
}
// append labels to _labels matrix
for(size_t labelIdx = 0; labelIdx < labels.total(); labelIdx++) {
_labels.push_back(labels.at<int>((int)labelIdx));