make the correctness test pass

This commit is contained in:
Vadim Pisarevsky 2014-03-07 18:18:10 +04:00
parent 8e1918e86e
commit 06c138bd64
2 changed files with 137 additions and 1191 deletions

View File

@ -250,6 +250,11 @@ void FAST_t(InputArray _img, std::vector<KeyPoint>& keypoints, int threshold, bo
}
}
template<typename pt>
struct cmp_pt
{
bool operator ()(const pt& a, const pt& b) const { return a.y < b.y || (a.y == b.y && a.x < b.x); }
};
static bool ocl_FAST( InputArray _img, std::vector<KeyPoint>& keypoints,
int threshold, bool nonmax_suppression, int maxKeypoints )
@ -265,8 +270,8 @@ static bool ocl_FAST( InputArray _img, std::vector<KeyPoint>& keypoints,
UMat kp1(1, maxKeypoints*2+1, CV_32S), score;
UMat utemp(kp1, Rect(0,0,1,1));
utemp.setTo(Scalar::all(0));
UMat ucounter1(kp1, Rect(0,0,1,1));
ucounter1.setTo(Scalar::all(0));
if( nonmax_suppression )
{
@ -284,10 +289,15 @@ static bool ocl_FAST( InputArray _img, std::vector<KeyPoint>& keypoints,
return false;
Mat mcounter;
utemp.copyTo(mcounter);
ucounter1.copyTo(mcounter);
int i, counter = mcounter.at<int>(0);
counter = std::min(counter, maxKeypoints);
keypoints.clear();
if( counter == 0 )
return true;
if( !nonmax_suppression )
{
Mat m;
@ -299,24 +309,29 @@ static bool ocl_FAST( InputArray _img, std::vector<KeyPoint>& keypoints,
else
{
UMat kp2(1, maxKeypoints*3+1, CV_32S);
utemp = kp2(Rect(0,0,1,1));
utemp.setTo(Scalar::all(0));
UMat ucounter2 = kp2(Rect(0,0,1,1));
ucounter2.setTo(Scalar::all(0));
ocl::Kernel fastNMSKernel("FAST_nonmaxSupression", ocl::features2d::fast_oclsrc);
if (fastNMSKernel.empty())
return false;
size_t globalsize_nms[] = { counter };
if( !fastNMSKernel.args(ocl::KernelArg::PtrReadOnly(kp1),
ocl::KernelArg::PtrReadWrite(kp2),
ocl::KernelArg::ReadOnlyNoSize(score),
counter, maxKeypoints).run(2, globalsize, 0, true))
counter, counter).run(1, globalsize_nms, 0, true))
return false;
Mat m;
kp2(Rect(0, 0, counter*3+1, 1)).copyTo(m);
const Point3i* pt = (const Point3i*)(m.ptr<int>() + 1);
for( i = 0; i < counter; i++ )
keypoints.push_back(KeyPoint((float)pt[i].x, (float)pt[i].y, 7.f, -1, (float)pt[i].z));
Mat m2;
kp2(Rect(0, 0, counter*3+1, 1)).copyTo(m2);
Point3i* pt2 = (Point3i*)(m2.ptr<int>() + 1);
int newcounter = std::min(m2.at<int>(0), counter);
std::sort(pt2, pt2 + newcounter, cmp_pt<Point3i>());
for( i = 0; i < newcounter; i++ )
keypoints.push_back(KeyPoint((float)pt2[i].x, (float)pt2[i].y, 7.f, -1, (float)pt2[i].z));
}
return true;
@ -325,7 +340,7 @@ static bool ocl_FAST( InputArray _img, std::vector<KeyPoint>& keypoints,
void FAST(InputArray _img, std::vector<KeyPoint>& keypoints, int threshold, bool nonmax_suppression, int type)
{
if( ocl::useOpenCL() && /*_img.isUMat() &&*/ type == FastFeatureDetector::TYPE_9_16 &&
if( ocl::useOpenCL() && _img.isUMat() && type == FastFeatureDetector::TYPE_9_16 &&
ocl_FAST(_img, keypoints, threshold, nonmax_suppression, 10000))
return;

File diff suppressed because it is too large Load Diff