added c++ wrapper of latent svm

This commit is contained in:
Maria Dimashova
2011-10-03 16:45:46 +00:00
parent 8a799aa89a
commit 4d85ee7de1
2 changed files with 158 additions and 0 deletions

View File

@@ -141,3 +141,122 @@ CvSeq* cvLatentSvmDetectObjects(IplImage* image,
return result_seq;
}
namespace cv
{
LatentSvmDetector::ObjectDetection::ObjectDetection() : score(0.f), classID(-1)
{}
LatentSvmDetector::ObjectDetection::ObjectDetection( const Rect& _rect, float _score, int _classID ) :
rect(_rect), score(_score), classID(_classID)
{}
LatentSvmDetector::LatentSvmDetector()
{}
LatentSvmDetector::LatentSvmDetector( const vector<string>& filenames, const vector<string>& _classNames )
{
load( filenames, _classNames );
}
LatentSvmDetector::~LatentSvmDetector()
{
clear();
}
void LatentSvmDetector::clear()
{
for( size_t i = 0; i < detectors.size(); i++ )
cvReleaseLatentSvmDetector( &detectors[i] );
detectors.clear();
}
bool LatentSvmDetector::empty() const
{
return detectors.empty();
}
const vector<string>& LatentSvmDetector::getClassNames() const
{
return classNames;
}
size_t LatentSvmDetector::getClassCount() const
{
return classNames.size();
}
string extractModelName( const string& filename )
{
size_t startPos = filename.rfind('/');
if( startPos == string::npos )
startPos = filename.rfind('\\');
if( startPos == string::npos )
startPos = 0;
else
startPos++;
const int extentionSize = 4; //.xml
int substrLength = filename.size() - startPos - extentionSize;
return filename.substr(startPos, substrLength);
}
bool LatentSvmDetector::load( const vector<string>& filenames, const vector<string>& _classNames )
{
clear();
CV_Assert( _classNames.empty() || _classNames.size() == filenames.size() );
for( size_t i = 0; i < filenames.size(); i++ )
{
const string filename = filenames[i];
if( filename.length() < 5 || filename.substr(filename.length()-4, 4) != ".xml" )
continue;
CvLatentSvmDetector* detector = cvLoadLatentSvmDetector( filename.c_str() );
if( detector )
{
detectors.push_back( detector );
if( _classNames.empty() )
{
classNames.push_back( extractModelName(filenames[i]) );
}
else
classNames.push_back( _classNames[i] );
}
}
return !empty();
}
void LatentSvmDetector::detect( const Mat& image,
vector<ObjectDetection>& objectDetections,
float overlapThreshold,
int numThreads )
{
objectDetections.clear();
if( numThreads <= 0 )
numThreads = 1;
for( size_t classID = 0; classID < detectors.size(); classID++ )
{
IplImage image_ipl = image;
CvMemStorage* storage = cvCreateMemStorage(0);
CvSeq* detections = cvLatentSvmDetectObjects( &image_ipl, detectors[classID], storage, overlapThreshold, numThreads );
// convert results
objectDetections.reserve( objectDetections.size() + detections->total );
for( int detectionIdx = 0; detectionIdx < detections->total; detectionIdx++ )
{
CvObjectDetection detection = *(CvObjectDetection*)cvGetSeqElem( detections, detectionIdx );
objectDetections.push_back( ObjectDetection(Rect(detection.rect), detection.score, (int)classID) );
}
cvReleaseMemStorage( &storage );
}
}
} // namespace cv