fix according to pull requests comments

This commit is contained in:
marina.kolpakova 2012-12-12 04:59:48 +04:00
parent 88c71d1b7d
commit 2d45af790e
2 changed files with 27 additions and 40 deletions

View File

@ -540,7 +540,7 @@ public:
// Param minScale is a minimum scale relative to the original size of the image on which cascade will be applyed. // Param minScale is a minimum scale relative to the original size of the image on which cascade will be applyed.
// Param minScale is a maximum scale relative to the original size of the image on which cascade will be applyed. // Param minScale is a maximum scale relative to the original size of the image on which cascade will be applyed.
// Param scales is a number of scales from minScale to maxScale. // Param scales is a number of scales from minScale to maxScale.
// Param rejfactor is used for NMS. // Param rejCriteria is used for NMS.
CV_WRAP SCascade(const double minScale = 0.4, const double maxScale = 5., const int scales = 55, const int rejCriteria = 1); CV_WRAP SCascade(const double minScale = 0.4, const double maxScale = 5., const int scales = 55, const int rejCriteria = 1);
CV_WRAP virtual ~SCascade(); CV_WRAP virtual ~SCascade();

View File

@ -443,15 +443,8 @@ namespace {
typedef cv::SCascade::Detection Detection; typedef cv::SCascade::Detection Detection;
typedef std::vector<Detection> dvector; typedef std::vector<Detection> dvector;
struct NMS
{
virtual ~NMS(){} struct ConfidenceGt
virtual void apply(dvector& objects) const = 0;
};
struct ConfidenceLess
{ {
bool operator()(const Detection& a, const Detection& b) const bool operator()(const Detection& a, const Detection& b) const
{ {
@ -459,44 +452,40 @@ struct ConfidenceLess
} }
}; };
struct DollarNMS: public NMS static float overlap(const cv::Rect &a, const cv::Rect &b)
{ {
virtual ~DollarNMS(){} int w = std::min(a.x + a.width, b.x + b.width) - std::max(a.x, b.x);
int h = std::min(a.y + a.height, b.y + b.height) - std::max(a.y, b.y);
static float overlap(const cv::Rect &a, const cv::Rect &b) return (w < 0 || h < 0)? 0.f : (float)(w * h);
}
void DollarNMS(dvector& objects)
{
static const float DollarThreshold = 0.65f;
std::sort(objects.begin(), objects.end(), ConfidenceGt());
for (dvector::iterator dIt = objects.begin(); dIt != objects.end(); ++dIt)
{ {
int w = std::min(a.x + a.width, b.x + b.width) - std::max(a.x, b.x); const Detection &a = *dIt;
int h = std::min(a.y + a.height, b.y + b.height) - std::max(a.y, b.y); for (dvector::iterator next = dIt + 1; next != objects.end(); )
return (w < 0 || h < 0)? 0.f : (float)(w * h);
}
virtual void apply(dvector& objects) const
{
std::sort(objects.begin(), objects.end(), ConfidenceLess());
for (dvector::iterator dIt = objects.begin(); dIt != objects.end(); ++dIt)
{ {
const Detection &a = *dIt; const Detection &b = *next;
for (dvector::iterator next = dIt + 1; next != objects.end(); )
{
const Detection &b = *next;
const float ovl = overlap(a.bb, b.bb) / std::min(a.bb.area(), b.bb.area()); const float ovl = overlap(a.bb, b.bb) / std::min(a.bb.area(), b.bb.area());
if (ovl > 0.65f) if (ovl > DollarThreshold)
next = objects.erase(next); next = objects.erase(next);
else else
++next; ++next;
}
} }
} }
}; }
cv::Ptr<NMS> createNMS(int type) static void suppress(int type, std::vector<Detection>& objects)
{ {
CV_Assert(type == cv::SCascade::DOLLAR); CV_Assert(type == cv::SCascade::DOLLAR);
return cv::Ptr<NMS>(new DollarNMS); DollarNMS(objects);
} }
} }
@ -522,8 +511,7 @@ void cv::SCascade::detectNoRoi(const cv::Mat& image, std::vector<Detection>& obj
} }
} }
if (rejCriteria != NO_REJECT) if (rejCriteria != NO_REJECT) suppress(rejCriteria, objects);
createNMS(rejCriteria)->apply(objects);
} }
void cv::SCascade::detect(cv::InputArray _image, cv::InputArray _rois, std::vector<Detection>& objects) const void cv::SCascade::detect(cv::InputArray _image, cv::InputArray _rois, std::vector<Detection>& objects) const
@ -572,8 +560,7 @@ void cv::SCascade::detect(cv::InputArray _image, cv::InputArray _rois, std::vect
} }
} }
if (rejCriteria != NO_REJECT) if (rejCriteria != NO_REJECT) suppress(rejCriteria, objects);
createNMS(rejCriteria)->apply(objects);
} }
void cv::SCascade::detect(InputArray _image, InputArray _rois, OutputArray _rects, OutputArray _confs) const void cv::SCascade::detect(InputArray _image, InputArray _rois, OutputArray _rects, OutputArray _confs) const