diff --git a/modules/objdetect/include/opencv2/objdetect/objdetect.hpp b/modules/objdetect/include/opencv2/objdetect/objdetect.hpp
index 5e2409378..e97828a62 100644
--- a/modules/objdetect/include/opencv2/objdetect/objdetect.hpp
+++ b/modules/objdetect/include/opencv2/objdetect/objdetect.hpp
@@ -286,6 +286,70 @@ namespace cv
 
 ///////////////////////////// Object Detection ////////////////////////////
 
+/*
+ * This class wraps the C structure CvLatentSvmDetector and the functions working with it.
+ * The class goals are:
+ * 1) provide a C++ interface;
+ * 2) make it possible to load and detect more than one class (model), unlike CvLatentSvmDetector.
+ *
+ * The object owns the loaded C detectors and releases them in clear() and in the destructor,
+ * so copying is disabled: a copy would release the same detectors a second time.
+ */
+class CV_EXPORTS_W LatentSvmDetector
+{
+public:
+    /* One detection: bounding box, score and the index of the model that produced it. */
+    struct CV_EXPORTS_W ObjectDetection
+    {
+        /* Default-constructed rect, score 0, classID -1. */
+        ObjectDetection();
+        ObjectDetection( const Rect& rect, float score, int classID=-1 );
+        Rect rect;
+        float score;
+        int classID; // index of the model as set by detect(); -1 when not assigned
+    };
+
+    /* Creates an empty detector; call load() to add models. */
+    CV_WRAP LatentSvmDetector();
+    /* Creates a detector and immediately calls load() with the same arguments. */
+    CV_WRAP LatentSvmDetector( const vector<string>& filenames, const vector<string>& classNames=vector<string>() );
+    virtual ~LatentSvmDetector();
+
+    /* Releases all loaded models. */
+    CV_WRAP virtual void clear();
+    /* Returns true when no model is loaded. */
+    CV_WRAP virtual bool empty() const;
+    /*
+     * Replaces the loaded models with the ones read from filenames. Entries that are not
+     * named "<something>.xml" or that fail to load are skipped. classNames must be empty
+     * (names are then derived from the file names) or have the same size as filenames.
+     * Returns true if at least one model was loaded.
+     */
+    CV_WRAP bool load( const vector<string>& filenames, const vector<string>& classNames=vector<string>() );
+
+    /*
+     * Runs every loaded model on image and stores all detections in objectDetections
+     * (its previous content is discarded). overlapThreshold is passed on to
+     * cvLatentSvmDetectObjects; a non-positive numThreads is treated as 1.
+     */
+    CV_WRAP virtual void detect( const Mat& image,
+                                 vector<ObjectDetection>& objectDetections,
+                                 float overlapThreshold=0.5f,
+                                 int numThreads=-1 );
+
+    /* Class names collected by load(); names are appended in the same order as the models. */
+    const vector<string>& getClassNames() const;
+    size_t getClassCount() const;
+
+private:
+    // Declared but not defined: copying is disabled because the class owns raw C detectors.
+    LatentSvmDetector( const LatentSvmDetector& );
+    LatentSvmDetector& operator=( const LatentSvmDetector& );
+
+    vector<CvLatentSvmDetector*> detectors;
+    vector<string> classNames;
+};
+
 CV_EXPORTS void groupRectangles(CV_OUT CV_IN_OUT vector<Rect>& rectList, int groupThreshold, double eps=0.2);
 CV_EXPORTS_W void groupRectangles(CV_OUT CV_IN_OUT vector<Rect>& rectList, CV_OUT vector<int>& weights, int groupThreshold, double eps=0.2);
 CV_EXPORTS void groupRectangles( vector<Rect>& rectList, int groupThreshold, double eps, vector<int>* weights, vector<double>* levelWeights );
diff --git a/modules/objdetect/src/latentsvmdetector.cpp b/modules/objdetect/src/latentsvmdetector.cpp
index 5b39e45c9..d8e8f7a4a 100644
--- 
a/modules/objdetect/src/latentsvmdetector.cpp
+++ b/modules/objdetect/src/latentsvmdetector.cpp
@@ -141,3 +141,130 @@ CvSeq* cvLatentSvmDetectObjects(IplImage* image,
 
     return result_seq;
 }
+
+namespace cv
+{
+// Default detection: default-constructed rect, zero score, no class assigned (-1).
+LatentSvmDetector::ObjectDetection::ObjectDetection() : score(0.f), classID(-1)
+{}
+
+LatentSvmDetector::ObjectDetection::ObjectDetection( const Rect& _rect, float _score, int _classID ) :
+    rect(_rect), score(_score), classID(_classID)
+{}
+
+LatentSvmDetector::LatentSvmDetector()
+{}
+
+// Loads the given models right away; empty() tells whether anything was loaded.
+LatentSvmDetector::LatentSvmDetector( const vector<string>& filenames, const vector<string>& _classNames )
+{
+    load( filenames, _classNames );
+}
+
+LatentSvmDetector::~LatentSvmDetector()
+{
+    clear();
+}
+
+// Releases every loaded C detector and forgets the class names that belong to them.
+void LatentSvmDetector::clear()
+{
+    for( size_t i = 0; i < detectors.size(); i++ )
+        cvReleaseLatentSvmDetector( &detectors[i] );
+    detectors.clear();
+
+    // classNames is filled in parallel with detectors; reset it as well so that a
+    // repeated load() does not keep the names of the previous models.
+    classNames.clear();
+}
+
+bool LatentSvmDetector::empty() const
+{
+    return detectors.empty();
+}
+
+const vector<string>& LatentSvmDetector::getClassNames() const
+{
+    return classNames;
+}
+
+size_t LatentSvmDetector::getClassCount() const
+{
+    return classNames.size();
+}
+
+// Returns the file name without its directories and without the last 4 characters
+// (the ".xml" extension that load() has already checked). A name shorter than the
+// extension is returned unchanged.
+static string extractModelName( const string& filename )
+{
+    // Use the last separator of either kind so that mixed '/' and '\\' paths work.
+    const size_t separatorPos = filename.find_last_of( "/\\" );
+    const size_t startPos = separatorPos == string::npos ? 0 : separatorPos + 1;
+
+    const size_t extensionSize = 4; //.xml
+    const size_t nameLength = filename.size() - startPos;
+    if( nameLength < extensionSize )
+        return filename.substr( startPos );
+
+    return filename.substr( startPos, nameLength - extensionSize );
+}
+
+// Replaces the loaded models with the ones read from filenames.
+// Entries that are not named "<something>.xml" or that cannot be loaded are skipped.
+// Returns true if at least one model was loaded.
+bool LatentSvmDetector::load( const vector<string>& filenames, const vector<string>& _classNames )
+{
+    // Validate the arguments first so that a bad call leaves the loaded models untouched.
+    CV_Assert( _classNames.empty() || _classNames.size() == filenames.size() );
+
+    clear();
+
+    for( size_t i = 0; i < filenames.size(); i++ )
+    {
+        const string& filename = filenames[i];
+        if( filename.length() < 5 || filename.substr(filename.length()-4, 4) != ".xml" )
+            continue;
+
+        CvLatentSvmDetector* detector = cvLoadLatentSvmDetector( filename.c_str() );
+        if( !detector )
+            continue;
+
+        detectors.push_back( detector );
+        classNames.push_back( _classNames.empty() ? extractModelName(filename) : _classNames[i] );
+    }
+
+    return !empty();
+}
+
+// Runs every loaded model on image and collects the results in objectDetections
+// (cleared first); classID of each result is the index of the model that produced it.
+void LatentSvmDetector::detect( const Mat& image,
+                                vector<ObjectDetection>& objectDetections,
+                                float overlapThreshold,
+                                int numThreads )
+{
+    objectDetections.clear();
+    if( numThreads <= 0 )
+        numThreads = 1;
+
+    for( size_t classID = 0; classID < detectors.size(); classID++ )
+    {
+        IplImage image_ipl = image;
+        // Ptr releases the storage on every exit path, including exceptions.
+        Ptr<CvMemStorage> storage( cvCreateMemStorage(0) );
+        CvSeq* detections = cvLatentSvmDetectObjects( &image_ipl, detectors[classID], storage, overlapThreshold, numThreads );
+        if( !detections )
+            continue; // defensive: the C API returns a pointer, skip the model if it is NULL
+
+        // convert results
+        objectDetections.reserve( objectDetections.size() + detections->total );
+        for( int detectionIdx = 0; detectionIdx < detections->total; detectionIdx++ )
+        {
+            CvObjectDetection detection = *(CvObjectDetection*)cvGetSeqElem( detections, detectionIdx );
+            objectDetections.push_back( ObjectDetection(Rect(detection.rect), detection.score, (int)classID) );
+        }
+    }
+}
+
+} // namespace cv