add xml serialization
This commit is contained in:
parent
69304611db
commit
c0f68ec400
@ -144,6 +144,8 @@ public:
|
|||||||
virtual float predict( const Mat& _sample, Mat& _votes, bool raw_mode, bool return_sum ) const;
|
virtual float predict( const Mat& _sample, Mat& _votes, bool raw_mode, bool return_sum ) const;
|
||||||
virtual void setRejectThresholds(cv::Mat& thresholds);
|
virtual void setRejectThresholds(cv::Mat& thresholds);
|
||||||
|
|
||||||
|
virtual void write( cv::FileStorage &fs, const Mat& thresholds = Mat()) const;
|
||||||
|
|
||||||
int logScale;
|
int logScale;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -155,6 +157,8 @@ protected:
|
|||||||
|
|
||||||
float predict( const Mat& _sample, const cv::Range range) const;
|
float predict( const Mat& _sample, const cv::Range range) const;
|
||||||
private:
|
private:
|
||||||
|
void traverse(const CvBoostTree* tree, cv::FileStorage& fs, const float* th = 0) const;
|
||||||
|
|
||||||
cv::Rect boundingBox;
|
cv::Rect boundingBox;
|
||||||
|
|
||||||
int npositives;
|
int npositives;
|
||||||
|
@ -47,6 +47,8 @@
|
|||||||
#include <opencv2/imgproc/imgproc.hpp>
|
#include <opencv2/imgproc/imgproc.hpp>
|
||||||
#include <opencv2/highgui/highgui.hpp>
|
#include <opencv2/highgui/highgui.hpp>
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
// ============ Octave ============ //
|
// ============ Octave ============ //
|
||||||
sft::Octave::Octave(cv::Rect bb, int np, int nn, int ls, int shr)
|
sft::Octave::Octave(cv::Rect bb, int np, int nn, int ls, int shr)
|
||||||
: logScale(ls), boundingBox(bb), npositives(np), nnegatives(nn), shrinkage(shr)
|
: logScale(ls), boundingBox(bb), npositives(np), nnegatives(nn), shrinkage(shr)
|
||||||
@ -293,6 +295,89 @@ void sft::Octave::generateNegatives(const Dataset& dataset)
|
|||||||
dprintf("Processing negatives finished:\n\trequested %d negatives, viewed %d samples.\n", nnegatives, total);
|
dprintf("Processing negatives finished:\n\trequested %d negatives, viewed %d samples.\n", nnegatives, total);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T> int sgn(T val) {
|
||||||
|
return (T(0) < val) - (val < T(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
void sft::Octave::traverse(const CvBoostTree* tree, cv::FileStorage& fs, const float* th) const
|
||||||
|
{
|
||||||
|
std::queue<const CvDTreeNode*> nodes;
|
||||||
|
nodes.push( tree->get_root());
|
||||||
|
const CvDTreeNode* tempNode;
|
||||||
|
int leafValIdx = 0;
|
||||||
|
int internalNodeIdx = 1;
|
||||||
|
float* leafs = new float[(int)pow(2.f, get_params().max_depth)];
|
||||||
|
|
||||||
|
fs << "{";
|
||||||
|
fs << "internalNodes" << "[";
|
||||||
|
while (!nodes.empty())
|
||||||
|
{
|
||||||
|
tempNode = nodes.front();
|
||||||
|
CV_Assert( tempNode->left );
|
||||||
|
if ( !tempNode->left->left && !tempNode->left->right)
|
||||||
|
{
|
||||||
|
leafs[-leafValIdx] = (float)tempNode->left->value;
|
||||||
|
fs << leafValIdx-- ;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
nodes.push( tempNode->left );
|
||||||
|
fs << internalNodeIdx++;
|
||||||
|
}
|
||||||
|
CV_Assert( tempNode->right );
|
||||||
|
if ( !tempNode->right->left && !tempNode->right->right)
|
||||||
|
{
|
||||||
|
leafs[-leafValIdx] = (float)tempNode->right->value;
|
||||||
|
fs << leafValIdx--;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
nodes.push( tempNode->right );
|
||||||
|
fs << internalNodeIdx++;
|
||||||
|
}
|
||||||
|
int fidx = tempNode->split->var_idx;
|
||||||
|
fs << fidx;
|
||||||
|
|
||||||
|
fs << tempNode->split->ord.c;
|
||||||
|
|
||||||
|
nodes.pop();
|
||||||
|
}
|
||||||
|
fs << "]";
|
||||||
|
|
||||||
|
fs << "leafValues" << "[";
|
||||||
|
for (int ni = 0; ni < -leafValIdx; ni++)
|
||||||
|
fs << ( (!th) ? leafs[ni] : (sgn(leafs[ni]) * *th));
|
||||||
|
fs << "]";
|
||||||
|
|
||||||
|
fs << "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
void sft::Octave::write( cv::FileStorage &fso, const Mat& thresholds) const
|
||||||
|
{
|
||||||
|
fso << "{"
|
||||||
|
<< "scale" << logScale
|
||||||
|
<< "weaks" << weak->total
|
||||||
|
<< "trees" << "[";
|
||||||
|
// should be replased with the H.L. one
|
||||||
|
CvSeqReader reader;
|
||||||
|
cvStartReadSeq( weak, &reader);
|
||||||
|
|
||||||
|
for(int i = 0; i < weak->total; i++ )
|
||||||
|
{
|
||||||
|
CvBoostTree* tree;
|
||||||
|
CV_READ_SEQ_ELEM( tree, reader );
|
||||||
|
|
||||||
|
if (!thresholds.empty())
|
||||||
|
traverse(tree, fso, thresholds.ptr<float>(0)+ i);
|
||||||
|
else
|
||||||
|
traverse(tree, fso);
|
||||||
|
}
|
||||||
|
//
|
||||||
|
|
||||||
|
fso << "]"
|
||||||
|
<< "}";
|
||||||
|
}
|
||||||
|
|
||||||
bool sft::Octave::train(const Dataset& dataset, const FeaturePool& pool, int weaks, int treeDepth)
|
bool sft::Octave::train(const Dataset& dataset, const FeaturePool& pool, int weaks, int treeDepth)
|
||||||
{
|
{
|
||||||
CV_Assert(treeDepth == 2);
|
CV_Assert(treeDepth == 2);
|
||||||
|
@ -94,16 +94,41 @@ int main(int argc, char** argv)
|
|||||||
|
|
||||||
// 2. check and open output file
|
// 2. check and open output file
|
||||||
cv::FileStorage fso(cfg.outXmlPath, cv::FileStorage::WRITE);
|
cv::FileStorage fso(cfg.outXmlPath, cv::FileStorage::WRITE);
|
||||||
if(!fs.isOpened())
|
if(!fso.isOpened())
|
||||||
{
|
{
|
||||||
std::cout << "Training stopped. Output classifier Xml file " << cfg.outXmlPath << " can't be opened." << std::endl << std::flush;
|
std::cout << "Training stopped. Output classifier Xml file " << cfg.outXmlPath << " can't be opened." << std::endl << std::flush;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cv::FileStorage fsr(cfg.outXmlPath + ".raw.xml" , cv::FileStorage::WRITE);
|
||||||
|
if(!fsr.isOpened())
|
||||||
|
{
|
||||||
|
std::cout << "Training stopped. Output classifier Xml file " <<cfg.outXmlPath + ".raw.xml" << " can't be opened." << std::endl << std::flush;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
// ovector strong;
|
// ovector strong;
|
||||||
// strong.reserve(cfg.octaves.size());
|
// strong.reserve(cfg.octaves.size());
|
||||||
|
|
||||||
// fso << "softcascade" << "{" << "octaves" << "[";
|
fso << cfg.cascadeName
|
||||||
|
<< "{"
|
||||||
|
<< "stageType" << "BOOST"
|
||||||
|
<< "featureType" << "ICF"
|
||||||
|
<< "octavesNum" << (int)cfg.octaves.size()
|
||||||
|
<< "width" << cfg.modelWinSize.width
|
||||||
|
<< "height" << cfg.modelWinSize.height
|
||||||
|
<< "shrinkage" << cfg.shrinkage
|
||||||
|
<< "octaves" << "[";
|
||||||
|
|
||||||
|
fsr << cfg.cascadeName
|
||||||
|
<< "{"
|
||||||
|
<< "stageType" << "BOOST"
|
||||||
|
<< "featureType" << "ICF"
|
||||||
|
<< "octavesNum" << (int)cfg.octaves.size()
|
||||||
|
<< "width" << cfg.modelWinSize.width
|
||||||
|
<< "height" << cfg.modelWinSize.height
|
||||||
|
<< "shrinkage" << cfg.shrinkage
|
||||||
|
<< "octaves" << "[";
|
||||||
|
|
||||||
// 3. Train all octaves
|
// 3. Train all octaves
|
||||||
for (ivector::const_iterator it = cfg.octaves.begin(); it != cfg.octaves.end(); ++it)
|
for (ivector::const_iterator it = cfg.octaves.begin(); it != cfg.octaves.end(); ++it)
|
||||||
@ -137,6 +162,8 @@ int main(int argc, char** argv)
|
|||||||
cv::Mat thresholds;
|
cv::Mat thresholds;
|
||||||
boost.setRejectThresholds(thresholds);
|
boost.setRejectThresholds(thresholds);
|
||||||
|
|
||||||
|
boost.write(fso, thresholds);
|
||||||
|
boost.write(fsr);
|
||||||
// std::cout << "thresholds " << thresholds << std::endl;
|
// std::cout << "thresholds " << thresholds << std::endl;
|
||||||
|
|
||||||
cv::FileStorage tfs(("thresholds." + cfg.resPath(it)).c_str(), cv::FileStorage::WRITE);
|
cv::FileStorage tfs(("thresholds." + cfg.resPath(it)).c_str(), cv::FileStorage::WRITE);
|
||||||
@ -146,7 +173,8 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fso << "]" << "}";
|
fso << "]" << "}";
|
||||||
|
fsr << "]" << "}";
|
||||||
|
|
||||||
// // // 6. Set thresolds
|
// // // 6. Set thresolds
|
||||||
// // cascade.prune();
|
// // cascade.prune();
|
||||||
|
@ -1580,8 +1580,11 @@ bool CvCascadeBoost::isErrDesired()
|
|||||||
for( int i = 0; i < sCount; i++ )
|
for( int i = 0; i < sCount; i++ )
|
||||||
if( ((CvCascadeBoostTrainData*)data)->featureEvaluator->getCls( i ) == 1.0F )
|
if( ((CvCascadeBoostTrainData*)data)->featureEvaluator->getCls( i ) == 1.0F )
|
||||||
eval[numPos++] = predict( i, true );
|
eval[numPos++] = predict( i, true );
|
||||||
|
|
||||||
icvSortFlt( &eval[0], numPos, 0 );
|
icvSortFlt( &eval[0], numPos, 0 );
|
||||||
|
|
||||||
int thresholdIdx = (int)((1.0F - minHitRate) * numPos);
|
int thresholdIdx = (int)((1.0F - minHitRate) * numPos);
|
||||||
|
|
||||||
threshold = eval[ thresholdIdx ];
|
threshold = eval[ thresholdIdx ];
|
||||||
numPosTrue = numPos - thresholdIdx;
|
numPosTrue = numPos - thresholdIdx;
|
||||||
for( int i = thresholdIdx - 1; i >= 0; i--)
|
for( int i = thresholdIdx - 1; i >= 0; i--)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user