finished OpenCL port of ORB

This commit is contained in:
Vadim Pisarevsky 2014-03-13 22:56:53 +04:00
parent efdfca7a11
commit 3e854fa6e5
4 changed files with 744 additions and 313 deletions

View File

@ -28,7 +28,7 @@ PERF_TEST_P(orb, detect, testing::Values(ORB_IMAGES))
TEST_CYCLE() detector(frame, mask, points); TEST_CYCLE() detector(frame, mask, points);
sort(points.begin(), points.end(), comparators::KeypointGreater()); sort(points.begin(), points.end(), comparators::KeypointGreater());
SANITY_CHECK_KEYPOINTS(points); SANITY_CHECK_KEYPOINTS(points, 1e-5);
} }
PERF_TEST_P(orb, extract, testing::Values(ORB_IMAGES)) PERF_TEST_P(orb, extract, testing::Values(ORB_IMAGES))
@ -72,6 +72,6 @@ PERF_TEST_P(orb, full, testing::Values(ORB_IMAGES))
TEST_CYCLE() detector(frame, mask, points, descriptors, false); TEST_CYCLE() detector(frame, mask, points, descriptors, false);
perf::sort(points, descriptors); perf::sort(points, descriptors);
SANITY_CHECK_KEYPOINTS(points); SANITY_CHECK_KEYPOINTS(points, 1e-5);
SANITY_CHECK(descriptors); SANITY_CHECK(descriptors);
} }

View File

@ -330,8 +330,7 @@ static bool ocl_FAST( InputArray _img, std::vector<KeyPoint>& keypoints,
void FAST(InputArray _img, std::vector<KeyPoint>& keypoints, int threshold, bool nonmax_suppression, int type) void FAST(InputArray _img, std::vector<KeyPoint>& keypoints, int threshold, bool nonmax_suppression, int type)
{ {
double t = (double)getTickCount(); if( ocl::useOpenCL() && _img.isUMat() && type == FastFeatureDetector::TYPE_9_16 &&
if( ocl::useOpenCL() && /*_img.isUMat() &&*/ type == FastFeatureDetector::TYPE_9_16 &&
ocl_FAST(_img, keypoints, threshold, nonmax_suppression, 10000)) ocl_FAST(_img, keypoints, threshold, nonmax_suppression, 10000))
; ;
@ -350,7 +349,6 @@ void FAST(InputArray _img, std::vector<KeyPoint>& keypoints, int threshold, bool
FAST_t<16>(_img, keypoints, threshold, nonmax_suppression); FAST_t<16>(_img, keypoints, threshold, nonmax_suppression);
break; break;
} }
printf("time=%.2fms\n", ((double)getTickCount() - t)*1000./getTickFrequency());
} }
@ -371,10 +369,16 @@ FastFeatureDetector::FastFeatureDetector( int _threshold, bool _nonmaxSuppressio
void FastFeatureDetector::detectImpl( InputArray _image, std::vector<KeyPoint>& keypoints, InputArray _mask ) const void FastFeatureDetector::detectImpl( InputArray _image, std::vector<KeyPoint>& keypoints, InputArray _mask ) const
{ {
Mat image = _image.getMat(), mask = _mask.getMat(), grayImage = image; Mat mask = _mask.getMat(), grayImage;
if( image.type() != CV_8U ) UMat ugrayImage;
cvtColor( image, grayImage, COLOR_BGR2GRAY ); _InputArray gray = _image;
FAST( grayImage, keypoints, threshold, nonmaxSuppression, type ); if( _image.type() != CV_8U )
{
_OutputArray ogray = _image.isUMat() ? _OutputArray(ugrayImage) : _OutputArray(grayImage);
cvtColor( _image, ogray, COLOR_BGR2GRAY );
gray = ogray;
}
FAST( gray, keypoints, threshold, nonmaxSuppression, type );
KeyPointsFilter::runByPixelsMask( keypoints, mask ); KeyPointsFilter::runByPixelsMask( keypoints, mask );
} }

View File

@ -0,0 +1,254 @@
// OpenCL port of the ORB feature detector and descriptor extractor
// Copyright (C) 2014, Itseez Inc. See the license at http://opencv.org
//
// The original code has been contributed by Peter Andreas Entschev, peter@entschev.com
#define LAYERINFO_SIZE 1
#define LAYERINFO_OFS 0
#define KEYPOINT_SIZE 3
#define ORIENTED_KEYPOINT_SIZE 4
#define KEYPOINT_X 0
#define KEYPOINT_Y 1
#define KEYPOINT_Z 2
#define KEYPOINT_ANGLE 3
/////////////////////////////////////////////////////////////
#ifdef ORB_RESPONSES
__kernel void
ORB_HarrisResponses(__global const uchar* imgbuf, int imgstep, int imgoffset0,
__global const int* layerinfo, __global const int* keypoints,
__global float* responses, int nkeypoints )
{
int idx = get_global_id(0);
if( idx < nkeypoints )
{
__global const int* kpt = keypoints + idx*KEYPOINT_SIZE;
__global const int* layer = layerinfo + kpt[KEYPOINT_Z]*LAYERINFO_SIZE;
__global const uchar* img = imgbuf + imgoffset0 + layer[LAYERINFO_OFS] +
(kpt[KEYPOINT_Y] - blockSize/2)*imgstep + (kpt[KEYPOINT_X] - blockSize/2);
int i, j;
int a = 0, b = 0, c = 0;
for( i = 0; i < blockSize; i++, img += imgstep-blockSize )
{
for( j = 0; j < blockSize; j++, img++ )
{
int Ix = (img[1] - img[-1])*2 + img[-imgstep+1] - img[-imgstep-1] + img[imgstep+1] - img[imgstep-1];
int Iy = (img[imgstep] - img[-imgstep])*2 + img[imgstep-1] - img[-imgstep-1] + img[imgstep+1] - img[-imgstep+1];
a += Ix*Ix;
b += Iy*Iy;
c += Ix*Iy;
}
}
responses[idx] = ((float)a * b - (float)c * c - HARRIS_K * (float)(a + b) * (a + b))*scale_sq_sq;
}
}
#endif
/////////////////////////////////////////////////////////////
#ifdef ORB_ANGLES
#define _DBL_EPSILON 2.2204460492503131e-16f
#define atan2_p1 (0.9997878412794807f*57.29577951308232f)
#define atan2_p3 (-0.3258083974640975f*57.29577951308232f)
#define atan2_p5 (0.1555786518463281f*57.29577951308232f)
#define atan2_p7 (-0.04432655554792128f*57.29577951308232f)
inline float fastAtan2( float y, float x )
{
float ax = fabs(x), ay = fabs(y);
float a, c, c2;
if( ax >= ay )
{
c = ay/(ax + _DBL_EPSILON);
c2 = c*c;
a = (((atan2_p7*c2 + atan2_p5)*c2 + atan2_p3)*c2 + atan2_p1)*c;
}
else
{
c = ax/(ay + _DBL_EPSILON);
c2 = c*c;
a = 90.f - (((atan2_p7*c2 + atan2_p5)*c2 + atan2_p3)*c2 + atan2_p1)*c;
}
if( x < 0 )
a = 180.f - a;
if( y < 0 )
a = 360.f - a;
return a;
}
__kernel void
ORB_ICAngle(__global const uchar* imgbuf, int imgstep, int imgoffset0,
__global const int* layerinfo, __global const int* keypoints,
__global float* responses, const __global int* u_max,
int nkeypoints, int half_k )
{
int idx = get_global_id(0);
if( idx < nkeypoints )
{
__global const int* kpt = keypoints + idx*KEYPOINT_SIZE;
__global const int* layer = layerinfo + kpt[KEYPOINT_Z]*LAYERINFO_SIZE;
__global const uchar* center = imgbuf + imgoffset0 + layer[LAYERINFO_OFS] +
kpt[KEYPOINT_Y]*imgstep + kpt[KEYPOINT_X];
int u, v, m_01 = 0, m_10 = 0;
// Treat the center line differently, v=0
for( u = -half_k; u <= half_k; u++ )
m_10 += u * center[u];
// Go line by line in the circular patch
for( v = 1; v <= half_k; v++ )
{
// Proceed over the two lines
int v_sum = 0;
int d = u_max[v];
for( u = -d; u <= d; u++ )
{
int val_plus = center[u + v*imgstep], val_minus = center[u - v*imgstep];
v_sum += (val_plus - val_minus);
m_10 += u * (val_plus + val_minus);
}
m_01 += v * v_sum;
}
// we do not use OpenCL's atan2 intrinsic,
// because we want to get _exactly_ the same results as the CPU version
responses[idx] = fastAtan2((float)m_01, (float)m_10);
}
}
#endif
/////////////////////////////////////////////////////////////
#ifdef ORB_DESCRIPTORS
__kernel void
ORB_computeDescriptor(__global const uchar* imgbuf, int imgstep, int imgoffset0,
__global const int* layerinfo, __global const int* keypoints,
__global uchar* _desc, const __global int* pattern,
int nkeypoints, int dsize )
{
int idx = get_global_id(0);
if( idx < nkeypoints )
{
int i;
__global const int* kpt = keypoints + idx*ORIENTED_KEYPOINT_SIZE;
__global const int* layer = layerinfo + kpt[KEYPOINT_Z]*LAYERINFO_SIZE;
__global const uchar* center = imgbuf + imgoffset0 + layer[LAYERINFO_OFS] +
kpt[KEYPOINT_Y]*imgstep + kpt[KEYPOINT_X];
float angle = as_float(kpt[KEYPOINT_ANGLE]);
angle *= 0.01745329251994329547f;
float sina = sin(angle);
float cosa = cos(angle);
__global uchar* desc = _desc + idx*dsize;
#define GET_VALUE(idx) \
center[mad24(convert_int_rte(pattern[(idx)*2] * sina + pattern[(idx)*2+1] * cosa), imgstep, \
convert_int_rte(pattern[(idx)*2] * cosa - pattern[(idx)*2+1] * sina))]
for( i = 0; i < dsize; i++ )
{
int val;
#if WTA_K == 2
int t0, t1;
t0 = GET_VALUE(0); t1 = GET_VALUE(1);
val = t0 < t1;
t0 = GET_VALUE(2); t1 = GET_VALUE(3);
val |= (t0 < t1) << 1;
t0 = GET_VALUE(4); t1 = GET_VALUE(5);
val |= (t0 < t1) << 2;
t0 = GET_VALUE(6); t1 = GET_VALUE(7);
val |= (t0 < t1) << 3;
t0 = GET_VALUE(8); t1 = GET_VALUE(9);
val |= (t0 < t1) << 4;
t0 = GET_VALUE(10); t1 = GET_VALUE(11);
val |= (t0 < t1) << 5;
t0 = GET_VALUE(12); t1 = GET_VALUE(13);
val |= (t0 < t1) << 6;
t0 = GET_VALUE(14); t1 = GET_VALUE(15);
val |= (t0 < t1) << 7;
pattern += 16*2;
#elif WTA_K == 3
int t0, t1, t2;
t0 = GET_VALUE(0); t1 = GET_VALUE(1); t2 = GET_VALUE(2);
val = t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0);
t0 = GET_VALUE(3); t1 = GET_VALUE(4); t2 = GET_VALUE(5);
val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 2;
t0 = GET_VALUE(6); t1 = GET_VALUE(7); t2 = GET_VALUE(8);
val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 4;
t0 = GET_VALUE(9); t1 = GET_VALUE(10); t2 = GET_VALUE(11);
val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 6;
pattern += 12*2;
#elif WTA_K == 4
int t0, t1, t2, t3, k, val;
int a, b;
t0 = GET_VALUE(0); t1 = GET_VALUE(1);
t2 = GET_VALUE(2); t3 = GET_VALUE(3);
a = 0, b = 2;
if( t1 > t0 ) t0 = t1, a = 1;
if( t3 > t2 ) t2 = t3, b = 3;
k = t0 > t2 ? a : b;
val = k;
t0 = GET_VALUE(4); t1 = GET_VALUE(5);
t2 = GET_VALUE(6); t3 = GET_VALUE(7);
a = 0, b = 2;
if( t1 > t0 ) t0 = t1, a = 1;
if( t3 > t2 ) t2 = t3, b = 3;
k = t0 > t2 ? a : b;
val |= k << 2;
t0 = GET_VALUE(8); t1 = GET_VALUE(9);
t2 = GET_VALUE(10); t3 = GET_VALUE(11);
a = 0, b = 2;
if( t1 > t0 ) t0 = t1, a = 1;
if( t3 > t2 ) t2 = t3, b = 3;
k = t0 > t2 ? a : b;
val |= k << 4;
t0 = GET_VALUE(12); t1 = GET_VALUE(13);
t2 = GET_VALUE(14); t3 = GET_VALUE(15);
a = 0, b = 2;
if( t1 > t0 ) t0 = t1, a = 1;
if( t3 > t2 ) t2 = t3, b = 3;
k = t0 > t2 ? a : b;
val |= k << 6;
pattern += 16*2;
#else
#error "unknown/undefined WTA_K value; should be 2, 3 or 4"
#endif
desc[i] = (uchar)val;
}
}
}
#endif

View File

@ -35,6 +35,7 @@
/** Authors: Ethan Rublee, Vincent Rabaud, Gary Bradski */ /** Authors: Ethan Rublee, Vincent Rabaud, Gary Bradski */
#include "precomp.hpp" #include "precomp.hpp"
#include "opencl_kernels.hpp"
#include <iterator> #include <iterator>
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -43,14 +44,86 @@ namespace cv
{ {
const float HARRIS_K = 0.04f; const float HARRIS_K = 0.04f;
const int DESCRIPTOR_SIZE = 32;
template<typename _Tp> inline void copyVectorToUMat(const std::vector<_Tp>& v, OutputArray um)
{
if(v.empty())
um.release();
Mat(1, (int)(v.size()*sizeof(v[0])), CV_8U, (void*)&v[0]).copyTo(um);
}
static bool
ocl_HarrisResponses(const UMat& imgbuf,
const UMat& layerinfo,
const UMat& keypoints,
UMat& responses,
int nkeypoints, int blockSize, float harris_k)
{
size_t globalSize[] = {nkeypoints};
float scale = 1.f/((1 << 2) * blockSize * 255.f);
float scale_sq_sq = scale * scale * scale * scale;
ocl::Kernel hr_ker("ORB_HarrisResponses", ocl::features2d::orb_oclsrc,
format("-D ORB_RESPONSES -D blockSize=%d -D scale_sq_sq=%.12ef -D HARRIS_K=%.12ff", blockSize, scale_sq_sq, harris_k));
if( hr_ker.empty() )
return false;
return hr_ker.args(ocl::KernelArg::ReadOnlyNoSize(imgbuf),
ocl::KernelArg::PtrReadOnly(layerinfo),
ocl::KernelArg::PtrReadOnly(keypoints),
ocl::KernelArg::PtrWriteOnly(responses),
nkeypoints).run(1, globalSize, 0, true);
}
static bool
ocl_ICAngles(const UMat& imgbuf, const UMat& layerinfo,
const UMat& keypoints, UMat& responses,
const UMat& umax, int nkeypoints, int half_k)
{
size_t globalSize[] = {nkeypoints};
ocl::Kernel icangle_ker("ORB_ICAngle", ocl::features2d::orb_oclsrc, "-D ORB_ANGLES");
if( icangle_ker.empty() )
return false;
return icangle_ker.args(ocl::KernelArg::ReadOnlyNoSize(imgbuf),
ocl::KernelArg::PtrReadOnly(layerinfo),
ocl::KernelArg::PtrReadOnly(keypoints),
ocl::KernelArg::PtrWriteOnly(responses),
ocl::KernelArg::PtrReadOnly(umax),
nkeypoints, half_k).run(1, globalSize, 0, true);
}
static bool
ocl_computeOrbDescriptors(const UMat& imgbuf, const UMat& layerInfo,
const UMat& keypoints, UMat& desc, const UMat& pattern,
int nkeypoints, int dsize, int WTA_K)
{
size_t globalSize[] = {nkeypoints};
ocl::Kernel desc_ker("ORB_computeDescriptor", ocl::features2d::orb_oclsrc,
format("-D ORB_DESCRIPTORS -D WTA_K=%d", WTA_K));
if( desc_ker.empty() )
return false;
return desc_ker.args(ocl::KernelArg::ReadOnlyNoSize(imgbuf),
ocl::KernelArg::PtrReadOnly(layerInfo),
ocl::KernelArg::PtrReadOnly(keypoints),
ocl::KernelArg::PtrWriteOnly(desc),
ocl::KernelArg::PtrReadOnly(pattern),
nkeypoints, dsize).run(1, globalSize, 0, true);
}
/** /**
* Function that computes the Harris responses in a * Function that computes the Harris responses in a
* blockSize x blockSize patch at given points in an image * blockSize x blockSize patch at given points in the image
*/ */
static void static void
HarrisResponses(const Mat& img, std::vector<KeyPoint>& pts, int blockSize, float harris_k) HarrisResponses(const Mat& img, const std::vector<Rect>& layerinfo,
std::vector<KeyPoint>& pts, int blockSize, float harris_k)
{ {
CV_Assert( img.type() == CV_8UC1 && blockSize*blockSize <= 2048 ); CV_Assert( img.type() == CV_8UC1 && blockSize*blockSize <= 2048 );
@ -60,8 +133,7 @@ HarrisResponses(const Mat& img, std::vector<KeyPoint>& pts, int blockSize, float
int step = (int)(img.step/img.elemSize1()); int step = (int)(img.step/img.elemSize1());
int r = blockSize/2; int r = blockSize/2;
float scale = (1 << 2) * blockSize * 255.0f; float scale = 1.f/((1 << 2) * blockSize * 255.f);
scale = 1.0f / scale;
float scale_sq_sq = scale * scale * scale * scale; float scale_sq_sq = scale * scale * scale * scale;
AutoBuffer<int> ofsbuf(blockSize*blockSize); AutoBuffer<int> ofsbuf(blockSize*blockSize);
@ -72,10 +144,11 @@ HarrisResponses(const Mat& img, std::vector<KeyPoint>& pts, int blockSize, float
for( ptidx = 0; ptidx < ptsize; ptidx++ ) for( ptidx = 0; ptidx < ptsize; ptidx++ )
{ {
int x0 = cvRound(pts[ptidx].pt.x - r); int x0 = cvRound(pts[ptidx].pt.x);
int y0 = cvRound(pts[ptidx].pt.y - r); int y0 = cvRound(pts[ptidx].pt.y);
int z = pts[ptidx].octave;
const uchar* ptr0 = ptr00 + y0*step + x0; const uchar* ptr0 = ptr00 + (y0 - r + layerinfo[z].y)*step + x0 - r + layerinfo[z].x;
int a = 0, b = 0, c = 0; int a = 0, b = 0, c = 0;
for( int k = 0; k < blockSize*blockSize; k++ ) for( int k = 0; k < blockSize*blockSize; k++ )
@ -94,158 +167,175 @@ HarrisResponses(const Mat& img, std::vector<KeyPoint>& pts, int blockSize, float
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
static float IC_Angle(const Mat& image, const int half_k, Point2f pt, static void ICAngles(const Mat& img, const std::vector<Rect>& layerinfo,
const std::vector<int> & u_max) std::vector<KeyPoint>& pts, const std::vector<int> & u_max, int half_k)
{ {
int m_01 = 0, m_10 = 0; int step = (int)img.step1();
size_t ptidx, ptsize = pts.size();
const uchar* center = &image.at<uchar> (cvRound(pt.y), cvRound(pt.x)); for( ptidx = 0; ptidx < ptsize; ptidx++ )
// Treat the center line differently, v=0
for (int u = -half_k; u <= half_k; ++u)
m_10 += u * center[u];
// Go line by line in the circular patch
int step = (int)image.step1();
for (int v = 1; v <= half_k; ++v)
{ {
// Proceed over the two lines const Rect& layer = layerinfo[pts[ptidx].octave];
int v_sum = 0; const uchar* center = &img.at<uchar>(cvRound(pts[ptidx].pt.y) + layer.y, cvRound(pts[ptidx].pt.x) + layer.x);
int d = u_max[v];
for (int u = -d; u <= d; ++u)
{
int val_plus = center[u + v*step], val_minus = center[u - v*step];
v_sum += (val_plus - val_minus);
m_10 += u * (val_plus + val_minus);
}
m_01 += v * v_sum;
}
return fastAtan2((float)m_01, (float)m_10); int m_01 = 0, m_10 = 0;
// Treat the center line differently, v=0
for (int u = -half_k; u <= half_k; ++u)
m_10 += u * center[u];
// Go line by line in the circular patch
for (int v = 1; v <= half_k; ++v)
{
// Proceed over the two lines
int v_sum = 0;
int d = u_max[v];
for (int u = -d; u <= d; ++u)
{
int val_plus = center[u + v*step], val_minus = center[u - v*step];
v_sum += (val_plus - val_minus);
m_10 += u * (val_plus + val_minus);
}
m_01 += v * v_sum;
}
pts[ptidx].angle = fastAtan2((float)m_01, (float)m_10);
}
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
static void computeOrbDescriptor(const KeyPoint& kpt, static void
const Mat& img, const Point* pattern, computeOrbDescriptors( const Mat& imagePyramid, const std::vector<Rect>& layerInfo,
uchar* desc, int dsize, int WTA_K) const std::vector<float>& layerScale, std::vector<KeyPoint>& keypoints,
Mat& descriptors, const std::vector<Point>& _pattern, int dsize, int WTA_K )
{ {
float angle = kpt.angle; int step = (int)imagePyramid.step;
//angle = cvFloor(angle/12)*12.f; int j, i, nkeypoints = (int)keypoints.size();
angle *= (float)(CV_PI/180.f);
float a = (float)cos(angle), b = (float)sin(angle);
const uchar* center = &img.at<uchar>(cvRound(kpt.pt.y), cvRound(kpt.pt.x)); for( j = 0; j < nkeypoints; j++ )
int step = (int)img.step; {
const KeyPoint& kpt = keypoints[j];
const Rect& layer = layerInfo[kpt.octave];
float scale = 1.f/layerScale[kpt.octave];
float angle = kpt.angle;
float x, y; angle *= (float)(CV_PI/180.f);
int ix, iy; float a = (float)cos(angle), b = (float)sin(angle);
#if 1
#define GET_VALUE(idx) \ const uchar* center = &imagePyramid.at<uchar>(cvRound(kpt.pt.y*scale) + layer.y,
(x = pattern[idx].x*a - pattern[idx].y*b, \ cvRound(kpt.pt.x*scale) + layer.x);
float x, y;
int ix, iy;
const Point* pattern = &_pattern[0];
uchar* desc = descriptors.ptr<uchar>(j);
#if 1
#define GET_VALUE(idx) \
(x = pattern[idx].x*a - pattern[idx].y*b, \
y = pattern[idx].x*b + pattern[idx].y*a, \
ix = cvRound(x), \
iy = cvRound(y), \
*(center + iy*step + ix) )
#else
#define GET_VALUE(idx) \
(x = pattern[idx].x*a - pattern[idx].y*b, \
y = pattern[idx].x*b + pattern[idx].y*a, \ y = pattern[idx].x*b + pattern[idx].y*a, \
ix = cvRound(x), \ ix = cvFloor(x), iy = cvFloor(y), \
iy = cvRound(y), \ x -= ix, y -= iy, \
*(center + iy*step + ix) ) cvRound(center[iy*step + ix]*(1-x)*(1-y) + center[(iy+1)*step + ix]*(1-x)*y + \
#else center[iy*step + ix+1]*x*(1-y) + center[(iy+1)*step + ix+1]*x*y))
#define GET_VALUE(idx) \ #endif
(x = pattern[idx].x*a - pattern[idx].y*b, \
y = pattern[idx].x*b + pattern[idx].y*a, \
ix = cvFloor(x), iy = cvFloor(y), \
x -= ix, y -= iy, \
cvRound(center[iy*step + ix]*(1-x)*(1-y) + center[(iy+1)*step + ix]*(1-x)*y + \
center[iy*step + ix+1]*x*(1-y) + center[(iy+1)*step + ix+1]*x*y))
#endif
if( WTA_K == 2 ) if( WTA_K == 2 )
{
for (int i = 0; i < dsize; ++i, pattern += 16)
{ {
int t0, t1, val; for (i = 0; i < dsize; ++i, pattern += 16)
t0 = GET_VALUE(0); t1 = GET_VALUE(1); {
val = t0 < t1; int t0, t1, val;
t0 = GET_VALUE(2); t1 = GET_VALUE(3); t0 = GET_VALUE(0); t1 = GET_VALUE(1);
val |= (t0 < t1) << 1; val = t0 < t1;
t0 = GET_VALUE(4); t1 = GET_VALUE(5); t0 = GET_VALUE(2); t1 = GET_VALUE(3);
val |= (t0 < t1) << 2; val |= (t0 < t1) << 1;
t0 = GET_VALUE(6); t1 = GET_VALUE(7); t0 = GET_VALUE(4); t1 = GET_VALUE(5);
val |= (t0 < t1) << 3; val |= (t0 < t1) << 2;
t0 = GET_VALUE(8); t1 = GET_VALUE(9); t0 = GET_VALUE(6); t1 = GET_VALUE(7);
val |= (t0 < t1) << 4; val |= (t0 < t1) << 3;
t0 = GET_VALUE(10); t1 = GET_VALUE(11); t0 = GET_VALUE(8); t1 = GET_VALUE(9);
val |= (t0 < t1) << 5; val |= (t0 < t1) << 4;
t0 = GET_VALUE(12); t1 = GET_VALUE(13); t0 = GET_VALUE(10); t1 = GET_VALUE(11);
val |= (t0 < t1) << 6; val |= (t0 < t1) << 5;
t0 = GET_VALUE(14); t1 = GET_VALUE(15); t0 = GET_VALUE(12); t1 = GET_VALUE(13);
val |= (t0 < t1) << 7; val |= (t0 < t1) << 6;
t0 = GET_VALUE(14); t1 = GET_VALUE(15);
val |= (t0 < t1) << 7;
desc[i] = (uchar)val; desc[i] = (uchar)val;
}
} }
} else if( WTA_K == 3 )
else if( WTA_K == 3 )
{
for (int i = 0; i < dsize; ++i, pattern += 12)
{ {
int t0, t1, t2, val; for (i = 0; i < dsize; ++i, pattern += 12)
t0 = GET_VALUE(0); t1 = GET_VALUE(1); t2 = GET_VALUE(2); {
val = t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0); int t0, t1, t2, val;
t0 = GET_VALUE(0); t1 = GET_VALUE(1); t2 = GET_VALUE(2);
val = t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0);
t0 = GET_VALUE(3); t1 = GET_VALUE(4); t2 = GET_VALUE(5); t0 = GET_VALUE(3); t1 = GET_VALUE(4); t2 = GET_VALUE(5);
val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 2; val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 2;
t0 = GET_VALUE(6); t1 = GET_VALUE(7); t2 = GET_VALUE(8); t0 = GET_VALUE(6); t1 = GET_VALUE(7); t2 = GET_VALUE(8);
val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 4; val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 4;
t0 = GET_VALUE(9); t1 = GET_VALUE(10); t2 = GET_VALUE(11); t0 = GET_VALUE(9); t1 = GET_VALUE(10); t2 = GET_VALUE(11);
val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 6; val |= (t2 > t1 ? (t2 > t0 ? 2 : 0) : (t1 > t0)) << 6;
desc[i] = (uchar)val; desc[i] = (uchar)val;
}
} }
} else if( WTA_K == 4 )
else if( WTA_K == 4 )
{
for (int i = 0; i < dsize; ++i, pattern += 16)
{ {
int t0, t1, t2, t3, u, v, k, val; for (i = 0; i < dsize; ++i, pattern += 16)
t0 = GET_VALUE(0); t1 = GET_VALUE(1); {
t2 = GET_VALUE(2); t3 = GET_VALUE(3); int t0, t1, t2, t3, u, v, k, val;
u = 0, v = 2; t0 = GET_VALUE(0); t1 = GET_VALUE(1);
if( t1 > t0 ) t0 = t1, u = 1; t2 = GET_VALUE(2); t3 = GET_VALUE(3);
if( t3 > t2 ) t2 = t3, v = 3; u = 0, v = 2;
k = t0 > t2 ? u : v; if( t1 > t0 ) t0 = t1, u = 1;
val = k; if( t3 > t2 ) t2 = t3, v = 3;
k = t0 > t2 ? u : v;
val = k;
t0 = GET_VALUE(4); t1 = GET_VALUE(5); t0 = GET_VALUE(4); t1 = GET_VALUE(5);
t2 = GET_VALUE(6); t3 = GET_VALUE(7); t2 = GET_VALUE(6); t3 = GET_VALUE(7);
u = 0, v = 2; u = 0, v = 2;
if( t1 > t0 ) t0 = t1, u = 1; if( t1 > t0 ) t0 = t1, u = 1;
if( t3 > t2 ) t2 = t3, v = 3; if( t3 > t2 ) t2 = t3, v = 3;
k = t0 > t2 ? u : v; k = t0 > t2 ? u : v;
val |= k << 2; val |= k << 2;
t0 = GET_VALUE(8); t1 = GET_VALUE(9); t0 = GET_VALUE(8); t1 = GET_VALUE(9);
t2 = GET_VALUE(10); t3 = GET_VALUE(11); t2 = GET_VALUE(10); t3 = GET_VALUE(11);
u = 0, v = 2; u = 0, v = 2;
if( t1 > t0 ) t0 = t1, u = 1; if( t1 > t0 ) t0 = t1, u = 1;
if( t3 > t2 ) t2 = t3, v = 3; if( t3 > t2 ) t2 = t3, v = 3;
k = t0 > t2 ? u : v; k = t0 > t2 ? u : v;
val |= k << 4; val |= k << 4;
t0 = GET_VALUE(12); t1 = GET_VALUE(13); t0 = GET_VALUE(12); t1 = GET_VALUE(13);
t2 = GET_VALUE(14); t3 = GET_VALUE(15); t2 = GET_VALUE(14); t3 = GET_VALUE(15);
u = 0, v = 2; u = 0, v = 2;
if( t1 > t0 ) t0 = t1, u = 1; if( t1 > t0 ) t0 = t1, u = 1;
if( t3 > t2 ) t2 = t3, v = 3; if( t3 > t2 ) t2 = t3, v = 3;
k = t0 > t2 ? u : v; k = t0 > t2 ? u : v;
val |= k << 6; val |= k << 6;
desc[i] = (uchar)val; desc[i] = (uchar)val;
}
} }
else
CV_Error( Error::StsBadSize, "Wrong WTA_K. It can be only 2, 3 or 4." );
#undef GET_VALUE
} }
else
CV_Error( Error::StsBadSize, "Wrong WTA_K. It can be only 2, 3 or 4." );
#undef GET_VALUE
} }
@ -591,21 +681,37 @@ void ORB::operator()(InputArray image, InputArray mask, std::vector<KeyPoint>& k
} }
/** Compute the ORB keypoint orientations static void uploadORBKeypoints(const std::vector<KeyPoint>& src, std::vector<Vec3i>& buf, OutputArray dst)
* @param image the image to compute the features and descriptors on
* @param integral_image the integral image of the iamge (can be empty, but the computation will be slower)
* @param scale the scale at which we compute the orientation
* @param keypoints the resulting keypoints
*/
static void computeOrientation(const Mat& image, std::vector<KeyPoint>& keypoints,
int halfPatchSize, const std::vector<int>& umax)
{ {
// Process each keypoint size_t i, n = src.size();
for (std::vector<KeyPoint>::iterator keypoint = keypoints.begin(), buf.resize(std::max(buf.size(), n));
keypointEnd = keypoints.end(); keypoint != keypointEnd; ++keypoint) for( i = 0; i < n; i++ )
buf[i] = Vec3i(cvRound(src[i].pt.x), cvRound(src[i].pt.y), src[i].octave);
copyVectorToUMat(buf, dst);
}
typedef union if32_t
{
int i;
float f;
}
if32_t;
static void uploadORBKeypoints(const std::vector<KeyPoint>& src,
const std::vector<float>& layerScale,
std::vector<Vec4i>& buf, OutputArray dst)
{
size_t i, n = src.size();
buf.resize(std::max(buf.size(), n));
for( i = 0; i < n; i++ )
{ {
keypoint->angle = IC_Angle(image, halfPatchSize, keypoint->pt, umax); int z = src[i].octave;
float scale = 1.f/layerScale[z];
if32_t angle;
angle.f = src[i].angle;
buf[i] = Vec4i(cvRound(src[i].pt.x*scale), cvRound(src[i].pt.y*scale), z, angle.i);
} }
copyVectorToUMat(buf, dst);
} }
@ -614,13 +720,18 @@ static void computeOrientation(const Mat& image, std::vector<KeyPoint>& keypoint
* @param mask_pyramid the masks to apply at every level * @param mask_pyramid the masks to apply at every level
* @param keypoints the resulting keypoints, clustered per level * @param keypoints the resulting keypoints, clustered per level
*/ */
static void computeKeyPoints(const std::vector<Mat>& imagePyramid, static void computeKeyPoints(const Mat& imagePyramid,
const std::vector<Mat>& maskPyramid, const UMat& uimagePyramid,
std::vector<std::vector<KeyPoint> >& allKeypoints, const Mat& maskPyramid,
int nfeatures, int firstLevel, double scaleFactor, const std::vector<Rect>& layerInfo,
int edgeThreshold, int patchSize, int scoreType ) const UMat& ulayerInfo,
const std::vector<float>& layerScale,
std::vector<KeyPoint>& allKeypoints,
int nfeatures, double scaleFactor,
int edgeThreshold, int patchSize, int scoreType,
bool useOCL )
{ {
int nlevels = (int)imagePyramid.size(); int i, nkeypoints, level, nlevels = (int)layerInfo.size();
std::vector<int> nfeaturesPerLevel(nlevels); std::vector<int> nfeaturesPerLevel(nlevels);
// fill the extractors and descriptors for the corresponding scales // fill the extractors and descriptors for the corresponding scales
@ -628,7 +739,7 @@ static void computeKeyPoints(const std::vector<Mat>& imagePyramid,
float ndesiredFeaturesPerScale = nfeatures*(1 - factor)/(1 - (float)std::pow((double)factor, (double)nlevels)); float ndesiredFeaturesPerScale = nfeatures*(1 - factor)/(1 - (float)std::pow((double)factor, (double)nlevels));
int sumFeatures = 0; int sumFeatures = 0;
for( int level = 0; level < nlevels-1; level++ ) for( level = 0; level < nlevels-1; level++ )
{ {
nfeaturesPerLevel[level] = cvRound(ndesiredFeaturesPerScale); nfeaturesPerLevel[level] = cvRound(ndesiredFeaturesPerScale);
sumFeatures += nfeaturesPerLevel[level]; sumFeatures += nfeaturesPerLevel[level];
@ -657,66 +768,116 @@ static void computeKeyPoints(const std::vector<Mat>& imagePyramid,
++v0; ++v0;
} }
allKeypoints.resize(nlevels); allKeypoints.clear();
std::vector<KeyPoint> keypoints;
std::vector<int> counters(nlevels);
keypoints.reserve(nfeaturesPerLevel[0]*2);
for (int level = 0; level < nlevels; ++level) for( level = 0; level < nlevels; level++ )
{ {
int featuresNum = nfeaturesPerLevel[level]; int featuresNum = nfeaturesPerLevel[level];
allKeypoints[level].reserve(featuresNum*2); Mat img = imagePyramid(layerInfo[level]);
Mat mask = maskPyramid.empty() ? Mat() : maskPyramid(layerInfo[level]);
std::vector<KeyPoint> & keypoints = allKeypoints[level];
// Detect FAST features, 20 is a good threshold // Detect FAST features, 20 is a good threshold
FastFeatureDetector fd(20, true); FastFeatureDetector fd(20, true);
fd.detect(imagePyramid[level], keypoints, maskPyramid[level]); fd.detect(img, keypoints, mask);
// Remove keypoints very close to the border // Remove keypoints very close to the border
KeyPointsFilter::runByImageBorder(keypoints, imagePyramid[level].size(), edgeThreshold); KeyPointsFilter::runByImageBorder(keypoints, img.size(), edgeThreshold);
if( scoreType == ORB::HARRIS_SCORE ) // Keep more points than necessary as FAST does not give amazing corners
KeyPointsFilter::retainBest(keypoints, scoreType == ORB::HARRIS_SCORE ? 2 * featuresNum : featuresNum);
nkeypoints = (int)keypoints.size();
counters[level] = nkeypoints;
float sf = layerScale[level];
for( i = 0; i < nkeypoints; i++ )
{ {
// Keep more points than necessary as FAST does not give amazing corners keypoints[i].octave = level;
KeyPointsFilter::retainBest(keypoints, 2 * featuresNum); keypoints[i].size = patchSize*sf;
// Compute the Harris cornerness (better scoring than FAST)
HarrisResponses(imagePyramid[level], keypoints, 7, HARRIS_K);
} }
//cull to the final desired level, using the new Harris scores or the original FAST scores. std::copy(keypoints.begin(), keypoints.end(), std::back_inserter(allKeypoints));
KeyPointsFilter::retainBest(keypoints, featuresNum);
float sf = getScale(level, firstLevel, scaleFactor);
// Set the level of the coordinates
for (std::vector<KeyPoint>::iterator keypoint = keypoints.begin(),
keypointEnd = keypoints.end(); keypoint != keypointEnd; ++keypoint)
{
keypoint->octave = level;
keypoint->size = patchSize*sf;
}
computeOrientation(imagePyramid[level], keypoints, halfPatchSize, umax);
} }
}
std::vector<Vec3i> ukeypoints_buf;
/** Compute the ORB decriptors nkeypoints = (int)allKeypoints.size();
* @param image the image to compute the features and descriptors on Mat responses;
* @param integral_image the integral image of the image (can be empty, but the computation will be slower) UMat ukeypoints, uresponses(1, nkeypoints, CV_32F);
* @param level the scale at which we compute the orientation
* @param keypoints the keypoints to use
* @param descriptors the resulting descriptors
*/
static void computeDescriptors(const Mat& image, std::vector<KeyPoint>& keypoints, Mat& descriptors,
const std::vector<Point>& pattern, int dsize, int WTA_K)
{
//convert to grayscale if more than one color
CV_Assert(image.type() == CV_8UC1);
//create the descriptor mat, keypoints.size() rows, BYTES cols
descriptors = Mat::zeros((int)keypoints.size(), dsize, CV_8UC1);
for (size_t i = 0; i < keypoints.size(); i++) // Select best features using the Harris cornerness (better scoring than FAST)
computeOrbDescriptor(keypoints[i], image, &pattern[0], descriptors.ptr((int)i), dsize, WTA_K); if( scoreType == ORB::HARRIS_SCORE )
{
if( useOCL )
{
uploadORBKeypoints(allKeypoints, ukeypoints_buf, ukeypoints);
useOCL = ocl_HarrisResponses( uimagePyramid, ulayerInfo, ukeypoints,
uresponses, nkeypoints, 7, HARRIS_K );
if( useOCL )
{
uresponses.copyTo(responses);
for( i = 0; i < nkeypoints; i++ )
allKeypoints[i].response = responses.at<float>(i);
}
}
if( !useOCL )
HarrisResponses(imagePyramid, layerInfo, allKeypoints, 7, HARRIS_K);
std::vector<KeyPoint> newAllKeypoints;
newAllKeypoints.reserve(nfeaturesPerLevel[0]*nlevels);
int offset = 0;
for( level = 0; level < nlevels; level++ )
{
int featuresNum = nfeaturesPerLevel[level];
nkeypoints = counters[level];
keypoints.resize(nkeypoints);
std::copy(allKeypoints.begin() + offset,
allKeypoints.begin() + offset + nkeypoints,
keypoints.begin());
offset += nkeypoints;
//cull to the final desired level, using the new Harris scores.
KeyPointsFilter::retainBest(keypoints, featuresNum);
std::copy(keypoints.begin(), keypoints.end(), std::back_inserter(newAllKeypoints));
}
std::swap(allKeypoints, newAllKeypoints);
}
nkeypoints = (int)allKeypoints.size();
if( useOCL )
{
UMat uumax;
if( useOCL )
copyVectorToUMat(umax, uumax);
uploadORBKeypoints(allKeypoints, ukeypoints_buf, ukeypoints);
useOCL = ocl_ICAngles(uimagePyramid, ulayerInfo, ukeypoints, uresponses, uumax,
nkeypoints, halfPatchSize);
if( useOCL )
{
uresponses.copyTo(responses);
for( i = 0; i < nkeypoints; i++ )
allKeypoints[i].angle = responses.at<float>(i);
}
}
if( !useOCL )
{
ICAngles(imagePyramid, layerInfo, allKeypoints, umax, halfPatchSize);
}
for( i = 0; i < nkeypoints; i++ )
{
float scale = layerScale[allKeypoints[i].octave];
allKeypoints[i].pt *= scale;
}
} }
@ -728,8 +889,8 @@ static void computeDescriptors(const Mat& image, std::vector<KeyPoint>& keypoint
* @param do_keypoints if true, the keypoints are computed, otherwise used as an input * @param do_keypoints if true, the keypoints are computed, otherwise used as an input
* @param do_descriptors if true, also computes the descriptors * @param do_descriptors if true, also computes the descriptors
*/ */
void ORB::operator()( InputArray _image, InputArray _mask, std::vector<KeyPoint>& _keypoints, void ORB::operator()( InputArray _image, InputArray _mask, std::vector<KeyPoint>& keypoints,
OutputArray _descriptors, bool useProvidedKeypoints) const OutputArray _descriptors, bool useProvidedKeypoints ) const
{ {
CV_Assert(patchSize >= 2); CV_Assert(patchSize >= 2);
@ -744,11 +905,14 @@ void ORB::operator()( InputArray _image, InputArray _mask, std::vector<KeyPoint>
int halfPatchSize = patchSize / 2; int halfPatchSize = patchSize / 2;
int border = std::max(edgeThreshold, std::max(halfPatchSize, HARRIS_BLOCK_SIZE/2))+1; int border = std::max(edgeThreshold, std::max(halfPatchSize, HARRIS_BLOCK_SIZE/2))+1;
bool useOCL = ocl::useOpenCL();
Mat image = _image.getMat(), mask = _mask.getMat(); Mat image = _image.getMat(), mask = _mask.getMat();
if( image.type() != CV_8UC1 ) if( image.type() != CV_8UC1 )
cvtColor(_image, image, COLOR_BGR2GRAY); cvtColor(_image, image, COLOR_BGR2GRAY);
int levelsNum = this->nlevels; int i, level, nLevels = this->nlevels, nkeypoints = (int)keypoints.size();
bool sortedByLevel = true;
if( !do_keypoints ) if( !do_keypoints )
{ {
@ -761,129 +925,145 @@ void ORB::operator()( InputArray _image, InputArray _mask, std::vector<KeyPoint>
// //
// In short, ultimately the descriptor should // In short, ultimately the descriptor should
// ignore octave parameter and deal only with the keypoint size. // ignore octave parameter and deal only with the keypoint size.
levelsNum = 0; nLevels = 0;
for( size_t i = 0; i < _keypoints.size(); i++ ) for( i = 0; i < nkeypoints; i++ )
levelsNum = std::max(levelsNum, std::max(_keypoints[i].octave, 0)); {
levelsNum++; level = keypoints[i].octave;
CV_Assert(level >= 0);
if( i > 0 && level < keypoints[i-1].octave )
sortedByLevel = false;
nLevels = std::max(nLevels, level);
}
nLevels++;
} }
// Pre-compute the scale pyramids std::vector<Rect> layerInfo(nLevels);
std::vector<Mat> imagePyramid(levelsNum), maskPyramid(levelsNum); std::vector<int> layerOfs(nLevels);
for (int level = 0; level < levelsNum; ++level) std::vector<float> layerScale(nLevels);
Mat imagePyramid, maskPyramid;
UMat uimagePyramid, ulayerInfo;
int level_dy = image.rows + border*2;
Point level_ofs(0,0);
Size bufSize((image.cols + border*2 + 15) & -16, 0);
for( level = 0; level < nLevels; level++ )
{ {
float scale = 1/getScale(level, firstLevel, scaleFactor); float scale = getScale(level, firstLevel, scaleFactor);
Size sz(cvRound(image.cols*scale), cvRound(image.rows*scale)); layerScale[level] = scale;
Size sz(cvRound(image.cols/scale), cvRound(image.rows/scale));
Size wholeSize(sz.width + border*2, sz.height + border*2); Size wholeSize(sz.width + border*2, sz.height + border*2);
Mat temp(wholeSize, image.type()), masktemp; if( level_ofs.x + wholeSize.width > bufSize.width )
imagePyramid[level] = temp(Rect(border, border, sz.width, sz.height)); {
level_ofs = Point(0, level_ofs.y + level_dy);
level_dy = wholeSize.height;
}
Rect linfo(level_ofs.x + border, level_ofs.y + border, sz.width, sz.height);
layerInfo[level] = linfo;
layerOfs[level] = linfo.y*bufSize.width + linfo.x;
level_ofs.x += wholeSize.width;
}
bufSize.height = level_ofs.y + level_dy;
imagePyramid.create(bufSize, CV_8U);
if( !mask.empty() )
maskPyramid.create(bufSize, CV_8U);
Mat prevImg = image, prevMask = mask;
// Pre-compute the scale pyramids
for (level = 0; level < nLevels; ++level)
{
Rect linfo = layerInfo[level];
Size sz(linfo.width, linfo.height);
Size wholeSize(sz.width + border*2, sz.height + border*2);
Rect wholeLinfo = Rect(linfo.x - border, linfo.y - border, wholeSize.width, wholeSize.height);
Mat extImg = imagePyramid(wholeLinfo), extMask;
Mat currImg = extImg(Rect(border, border, sz.width, sz.height)), currMask;
if( !mask.empty() ) if( !mask.empty() )
{ {
masktemp = Mat(wholeSize, mask.type()); extMask = maskPyramid(wholeLinfo);
maskPyramid[level] = masktemp(Rect(border, border, sz.width, sz.height)); currMask = extMask(Rect(border, border, sz.width, sz.height));
} }
// Compute the resized image // Compute the resized image
if( level != firstLevel ) if( level != firstLevel )
{ {
if( level < firstLevel ) resize(prevImg, currImg, sz, 0, 0, INTER_LINEAR);
if( !mask.empty() )
{ {
resize(image, imagePyramid[level], sz, 0, 0, INTER_LINEAR); resize(prevMask, currMask, sz, 0, 0, INTER_LINEAR);
if (!mask.empty()) if( level > firstLevel )
resize(mask, maskPyramid[level], sz, 0, 0, INTER_LINEAR); threshold(currMask, currMask, 254, 0, THRESH_TOZERO);
}
else
{
resize(imagePyramid[level-1], imagePyramid[level], sz, 0, 0, INTER_LINEAR);
if (!mask.empty())
{
resize(maskPyramid[level-1], maskPyramid[level], sz, 0, 0, INTER_LINEAR);
threshold(maskPyramid[level], maskPyramid[level], 254, 0, THRESH_TOZERO);
}
} }
copyMakeBorder(imagePyramid[level], temp, border, border, border, border, copyMakeBorder(currImg, extImg, border, border, border, border,
BORDER_REFLECT_101+BORDER_ISOLATED); BORDER_REFLECT_101+BORDER_ISOLATED);
if (!mask.empty()) if (!mask.empty())
copyMakeBorder(maskPyramid[level], masktemp, border, border, border, border, copyMakeBorder(currMask, extMask, border, border, border, border,
BORDER_CONSTANT+BORDER_ISOLATED); BORDER_CONSTANT+BORDER_ISOLATED);
} }
else else
{ {
copyMakeBorder(image, temp, border, border, border, border, copyMakeBorder(image, extImg, border, border, border, border,
BORDER_REFLECT_101); BORDER_REFLECT_101);
if( !mask.empty() ) if( !mask.empty() )
copyMakeBorder(mask, masktemp, border, border, border, border, copyMakeBorder(mask, extMask, border, border, border, border,
BORDER_CONSTANT+BORDER_ISOLATED); BORDER_CONSTANT+BORDER_ISOLATED);
} }
prevImg = currImg;
prevMask = currMask;
} }
// Pre-compute the keypoints (we keep the best over all scales, so this has to be done beforehand if( useOCL )
std::vector < std::vector<KeyPoint> > allKeypoints; copyVectorToUMat(layerOfs, ulayerInfo);
if( do_keypoints ) if( do_keypoints )
{ {
if( useOCL )
imagePyramid.copyTo(uimagePyramid);
// Get keypoints, those will be far enough from the border that no check will be required for the descriptor // Get keypoints, those will be far enough from the border that no check will be required for the descriptor
computeKeyPoints(imagePyramid, maskPyramid, allKeypoints, computeKeyPoints(imagePyramid, uimagePyramid, maskPyramid,
nfeatures, firstLevel, scaleFactor, layerInfo, ulayerInfo, layerScale, keypoints,
edgeThreshold, patchSize, scoreType); nfeatures, scaleFactor, edgeThreshold, patchSize, scoreType, useOCL);
// make sure we have the right number of keypoints keypoints
/*std::vector<KeyPoint> temp;
for (int level = 0; level < n_levels; ++level)
{
std::vector<KeyPoint>& keypoints = all_keypoints[level];
temp.insert(temp.end(), keypoints.begin(), keypoints.end());
keypoints.clear();
}
KeyPoint::retainBest(temp, n_features_);
for (std::vector<KeyPoint>::iterator keypoint = temp.begin(),
keypoint_end = temp.end(); keypoint != keypoint_end; ++keypoint)
all_keypoints[keypoint->octave].push_back(*keypoint);*/
} }
else else
{ {
// Remove keypoints very close to the border KeyPointsFilter::runByImageBorder(keypoints, image.size(), edgeThreshold);
KeyPointsFilter::runByImageBorder(_keypoints, image.size(), edgeThreshold);
// Cluster the input keypoints depending on the level they were computed at if( !sortedByLevel )
allKeypoints.resize(levelsNum);
for (std::vector<KeyPoint>::iterator keypoint = _keypoints.begin(),
keypointEnd = _keypoints.end(); keypoint != keypointEnd; ++keypoint)
allKeypoints[keypoint->octave].push_back(*keypoint);
// Make sure we rescale the coordinates
for (int level = 0; level < levelsNum; ++level)
{ {
if (level == firstLevel) std::vector<std::vector<KeyPoint> > allKeypoints(nLevels);
continue; nkeypoints = (int)keypoints.size();
for( i = 0; i < nkeypoints; i++ )
std::vector<KeyPoint> & keypoints = allKeypoints[level]; {
float scale = 1/getScale(level, firstLevel, scaleFactor); level = keypoints[i].octave;
for (std::vector<KeyPoint>::iterator keypoint = keypoints.begin(), CV_Assert(0 <= level);
keypointEnd = keypoints.end(); keypoint != keypointEnd; ++keypoint) allKeypoints[level].push_back(keypoints[i]);
keypoint->pt *= scale; }
keypoints.clear();
for( level = 0; level < nLevels; level++ )
std::copy(allKeypoints[level].begin(), allKeypoints[level].end(), std::back_inserter(keypoints));
} }
} }
Mat descriptors;
std::vector<Point> pattern;
if( do_descriptors ) if( do_descriptors )
{ {
int nkeypoints = 0; int dsize = descriptorSize();
for (int level = 0; level < levelsNum; ++level)
nkeypoints += (int)allKeypoints[level].size(); nkeypoints = (int)keypoints.size();
if( nkeypoints == 0 ) if( nkeypoints == 0 )
_descriptors.release();
else
{ {
_descriptors.create(nkeypoints, descriptorSize(), CV_8U); _descriptors.release();
descriptors = _descriptors.getMat(); return;
} }
_descriptors.create(nkeypoints, dsize, CV_8U);
std::vector<Point> pattern;
const int npoints = 512; const int npoints = 512;
Point patternbuf[npoints]; Point patternbuf[npoints];
const Point* pattern0 = (const Point*)bit_pattern_31_; const Point* pattern0 = (const Point*)bit_pattern_31_;
@ -903,43 +1083,36 @@ void ORB::operator()( InputArray _image, InputArray _mask, std::vector<KeyPoint>
int ntuples = descriptorSize()*4; int ntuples = descriptorSize()*4;
initializeOrbPattern(pattern0, pattern, ntuples, WTA_K, npoints); initializeOrbPattern(pattern0, pattern, ntuples, WTA_K, npoints);
} }
}
_keypoints.clear(); for( level = 0; level < nLevels; level++ )
int offset = 0;
for (int level = 0; level < levelsNum; ++level)
{
// Get the features and compute their orientation
std::vector<KeyPoint>& keypoints = allKeypoints[level];
int nkeypoints = (int)keypoints.size();
// Compute the descriptors
if (do_descriptors)
{ {
Mat desc;
if (!descriptors.empty())
{
desc = descriptors.rowRange(offset, offset + nkeypoints);
}
offset += nkeypoints;
// preprocess the resized image // preprocess the resized image
Mat& workingMat = imagePyramid[level]; Mat workingMat = imagePyramid(layerInfo[level]);
//boxFilter(working_mat, working_mat, working_mat.depth(), Size(5,5), Point(-1,-1), true, BORDER_REFLECT_101); //boxFilter(working_mat, working_mat, working_mat.depth(), Size(5,5), Point(-1,-1), true, BORDER_REFLECT_101);
GaussianBlur(workingMat, workingMat, Size(7, 7), 2, 2, BORDER_REFLECT_101); GaussianBlur(workingMat, workingMat, Size(7, 7), 2, 2, BORDER_REFLECT_101);
computeDescriptors(workingMat, keypoints, desc, pattern, descriptorSize(), WTA_K);
} }
// Copy to the output data if( useOCL )
if (level != firstLevel)
{ {
float scale = getScale(level, firstLevel, scaleFactor); imagePyramid.copyTo(uimagePyramid);
for (std::vector<KeyPoint>::iterator keypoint = keypoints.begin(), std::vector<Vec4i> kptbuf;
keypointEnd = keypoints.end(); keypoint != keypointEnd; ++keypoint) UMat ukeypoints, upattern;
keypoint->pt *= scale; copyVectorToUMat(pattern, upattern);
uploadORBKeypoints(keypoints, layerScale, kptbuf, ukeypoints);
UMat udescriptors = _descriptors.getUMat();
useOCL = ocl_computeOrbDescriptors(uimagePyramid, ulayerInfo,
ukeypoints, udescriptors, upattern,
nkeypoints, dsize, WTA_K);
}
if( !useOCL )
{
Mat descriptors = _descriptors.getMat();
computeOrbDescriptors(imagePyramid, layerInfo, layerScale,
keypoints, descriptors, pattern, dsize, WTA_K);
} }
// And add the keypoints to the output
_keypoints.insert(_keypoints.end(), keypoints.begin(), keypoints.end());
} }
} }