refactored gpu module, added vec math operators for uint, added 2-channel image support to gpu::sum (removed double support)

Alexey Spizhevoy 2010-12-15 15:12:32 +00:00
parent e5eec31be1
commit d8a7ff1e00
4 changed files with 292 additions and 71 deletions
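With this change, gpu::sum accepts 1- and 2-channel images and returns per-channel totals as a Scalar; double-precision inputs are no longer supported. A minimal host-side usage sketch (hedged: assumes a CUDA-capable device is available and selected):

    #include <opencv2/core/core.hpp>
    #include <opencv2/gpu/gpu.hpp>

    int main()
    {
        cv::Mat host(480, 640, CV_32FC2, cv::Scalar(1.f, 2.f)); // 2-channel source, every pixel {1, 2}
        cv::gpu::GpuMat dev(host);                              // upload to the device
        cv::Scalar s = cv::gpu::sum(dev);                       // per-channel sums: s[0] = 307200, s[1] = 614400
        return 0;
    }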


@@ -486,10 +486,10 @@ void cv::gpu::flip(const GpuMat& src, GpuMat& dst, int flipCode)
namespace cv { namespace gpu { namespace mathfunc
{
template <typename T>
- void sum_caller(const DevMem2D src, PtrStep buf, double* sum);
+ void sum_caller(const DevMem2D src, PtrStep buf, double* sum, int cn);
template <typename T>
- void sum_multipass_caller(const DevMem2D src, PtrStep buf, double* sum);
+ void sum_multipass_caller(const DevMem2D src, PtrStep buf, double* sum, int cn);
template <typename T>
void sqsum_caller(const DevMem2D src, PtrStep buf, double* sum);
@@ -499,7 +499,7 @@ namespace cv { namespace gpu { namespace mathfunc
namespace sum
{
- void get_buf_size_required(int cols, int rows, int& bufcols, int& bufrows);
+ void get_buf_size_required(int cols, int rows, int cn, int& bufcols, int& bufrows);
}
}}}
@@ -512,27 +512,26 @@ Scalar cv::gpu::sum(const GpuMat& src)
Scalar cv::gpu::sum(const GpuMat& src, GpuMat& buf)
{
using namespace mathfunc;
- CV_Assert(src.channels() == 1);
- typedef void (*Caller)(const DevMem2D, PtrStep, double*);
+ typedef void (*Caller)(const DevMem2D, PtrStep, double*, int);
static const Caller callers[2][7] =
{ { sum_multipass_caller<unsigned char>, sum_multipass_caller<char>,
sum_multipass_caller<unsigned short>, sum_multipass_caller<short>,
sum_multipass_caller<int>, sum_multipass_caller<float>, 0 },
{ sum_caller<unsigned char>, sum_caller<char>,
sum_caller<unsigned short>, sum_caller<short>,
- sum_caller<int>, sum_caller<float>, sum_caller<double> } };
+ sum_caller<int>, sum_caller<float>, 0 } };
Size bufSize;
- sum::get_buf_size_required(src.cols, src.rows, bufSize.width, bufSize.height);
+ sum::get_buf_size_required(src.cols, src.rows, src.channels(), bufSize.width, bufSize.height);
buf.create(bufSize, CV_8U);
- Caller caller = callers[hasAtomicsSupport(getDevice())][src.type()];
+ Caller caller = callers[hasAtomicsSupport(getDevice())][src.depth()];
if (!caller) CV_Error(CV_StsBadArg, "sum: unsupported type");
- double result;
- caller(src, buf, &result);
- return result;
+ double result[4];
+ caller(src, buf, result, src.channels());
+ return Scalar(result[0], result[1], result[2], result[3]);
}
Scalar cv::gpu::sqrSum(const GpuMat& src)
@@ -553,10 +552,10 @@ Scalar cv::gpu::sqrSum(const GpuMat& src, GpuMat& buf)
sqsum_multipass_caller<int>, sqsum_multipass_caller<float>, 0 },
{ sqsum_caller<unsigned char>, sqsum_caller<char>,
sqsum_caller<unsigned short>, sqsum_caller<short>,
- sqsum_caller<int>, sqsum_caller<float>, sqsum_caller<double> } };
+ sqsum_caller<int>, sqsum_caller<float>, 0 } };
Size bufSize;
- sum::get_buf_size_required(src.cols, src.rows, bufSize.width, bufSize.height);
+ sum::get_buf_size_required(src.cols, src.rows, 1, bufSize.width, bufSize.height);
buf.create(bufSize, CV_8U);
Caller caller = callers[hasAtomicsSupport(getDevice())][src.type()];
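Both sum and sqrSum now share the same dispatch scheme: the outer table index selects the multi-pass path (row 0, for devices without global atomics) or the single-pass path (row 1), the inner index is the element depth, and a 0 entry marks an unsupported depth, which after this commit includes CV_64F in both rows. A stripped-down sketch of that pattern, with hypothetical names rather than the library's actual code:

    typedef void (*Caller)(const void* src, double* sum, int cn); // hypothetical signature

    void dispatch(bool hasAtomics, int depth, const void* src, double* sum, int cn)
    {
        static const Caller callers[2][7] = { { 0 }, { 0 } }; // filled per depth; 0 = unsupported
        Caller c = callers[hasAtomics ? 1 : 0][depth];
        if (!c)
            return; // the real code raises CV_Error(CV_StsBadArg, "sum: unsupported type")
        c(src, sum, cn);
    }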


@@ -42,6 +42,7 @@
#include "opencv2/gpu/device/limits_gpu.hpp"
#include "opencv2/gpu/device/saturate_cast.hpp"
+ #include "opencv2/gpu/device/vecmath.hpp"
#include "transform.hpp"
#include "internal_shared.hpp"
@@ -1451,11 +1452,11 @@ namespace cv { namespace gpu { namespace mathfunc
}
- void get_buf_size_required(int cols, int rows, int& bufcols, int& bufrows)
+ void get_buf_size_required(int cols, int rows, int cn, int& bufcols, int& bufrows)
{
dim3 threads, grid;
estimate_thread_cfg(cols, rows, threads, grid);
- bufcols = grid.x * grid.y * sizeof(double);
+ bufcols = grid.x * grid.y * sizeof(double) * cn;
bufrows = 1;
}
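Each block contributes one partial result per channel, so the staging buffer now scales linearly with channel count. A quick worked example (hedged: the actual grid shape comes from estimate_thread_cfg, which this diff does not show):

    // for a hypothetical 40x30 block grid summing a 2-channel image:
    int bufcols = 40 * 30 * sizeof(double) * 2; // 19200 bytes: one double per block per channel
    int bufrows = 1;                            // all partials live in a single row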
@@ -1469,7 +1470,7 @@ namespace cv { namespace gpu { namespace mathfunc
}
template <typename T, typename R, typename Op, int nthreads>
- __global__ void sum_kernel(const DevMem2D_<T> src, R* result)
+ __global__ void sum_kernel(const DevMem2D src, R* result)
{
__shared__ R smem[nthreads];
@@ -1481,7 +1482,7 @@ namespace cv { namespace gpu { namespace mathfunc
R sum = 0;
for (int y = 0; y < ctheight && y0 + y * blockDim.y < src.rows; ++y)
{
- const T* ptr = src.ptr(y0 + y * blockDim.y);
+ const T* ptr = (const T*)src.ptr(y0 + y * blockDim.y);
for (int x = 0; x < ctwidth && x0 + x * blockDim.x < src.cols; ++x)
sum += Op::call(ptr[x0 + x * blockDim.x]);
}
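Both sum_kernel and the new two-channel kernels below fold per-block partials through sum_in_smem, which is not part of this diff; presumably it is the standard shared-memory tree reduction. A hypothetical reconstruction of such a helper (a sketch, not the file's actual code):

    template <int nthreads, typename R>
    static __device__ void sum_in_smem(volatile R* data, unsigned int tid)
    {
        // tree reduction: each step halves the number of active threads
        if (nthreads >= 512) { if (tid < 256) data[tid] = data[tid] + data[tid + 256]; __syncthreads(); }
        if (nthreads >= 256) { if (tid < 128) data[tid] = data[tid] + data[tid + 128]; __syncthreads(); }
        if (nthreads >= 128) { if (tid < 64)  data[tid] = data[tid] + data[tid + 64];  __syncthreads(); }
        if (tid < 32)
        {
            // within one warp these GPUs execute in lockstep, so no barrier is needed
            if (nthreads >= 64) data[tid] = data[tid] + data[tid + 32];
            if (nthreads >= 32) data[tid] = data[tid] + data[tid + 16];
            if (nthreads >= 16) data[tid] = data[tid] + data[tid + 8];
            if (nthreads >= 8)  data[tid] = data[tid] + data[tid + 4];
            if (nthreads >= 4)  data[tid] = data[tid] + data[tid + 2];
            if (nthreads >= 2)  data[tid] = data[tid] + data[tid + 1];
        }
    }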
@@ -1539,11 +1540,116 @@ namespace cv { namespace gpu { namespace mathfunc
result[0] = smem[0];
}
+ template <typename T, typename R, typename Op, int nthreads>
+ __global__ void sum_kernel_C2(const DevMem2D src, typename TypeVec<R, 2>::vec_t* result)
+ {
+ typedef typename TypeVec<T, 2>::vec_t SrcType;
+ typedef typename TypeVec<R, 2>::vec_t DstType;
+ __shared__ R smem[nthreads * 2];
+ const int x0 = blockIdx.x * blockDim.x * ctwidth + threadIdx.x;
+ const int y0 = blockIdx.y * blockDim.y * ctheight + threadIdx.y;
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x;
+ const int bid = blockIdx.y * gridDim.x + blockIdx.x;
+ SrcType val;
+ DstType sum = VecTraits<DstType>::all(0);
+ for (int y = 0; y < ctheight && y0 + y * blockDim.y < src.rows; ++y)
+ {
+ const SrcType* ptr = (const SrcType*)src.ptr(y0 + y * blockDim.y);
+ for (int x = 0; x < ctwidth && x0 + x * blockDim.x < src.cols; ++x)
+ {
+ val = ptr[x0 + x * blockDim.x];
+ sum = sum + VecTraits<DstType>::make(Op::call(val.x), Op::call(val.y));
+ }
+ }
+ smem[tid] = sum.x;
+ smem[tid + nthreads] = sum.y;
+ __syncthreads();
+ sum_in_smem<nthreads, R>(smem, tid);
+ sum_in_smem<nthreads, R>(smem + nthreads, tid);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 110
+ __shared__ bool is_last;
+ if (tid == 0)
+ {
+ DstType res;
+ res.x = smem[0];
+ res.y = smem[nthreads];
+ result[bid] = res;
+ __threadfence();
+ unsigned int ticket = atomicInc(&blocks_finished, gridDim.x * gridDim.y);
+ is_last = (ticket == gridDim.x * gridDim.y - 1);
+ }
+ __syncthreads();
+ if (is_last)
+ {
+ DstType res = tid < gridDim.x * gridDim.y ? result[tid] : VecTraits<DstType>::all(0);
+ smem[tid] = res.x;
+ smem[tid + nthreads] = res.y;
+ __syncthreads();
+ sum_in_smem<nthreads, R>(smem, tid);
+ sum_in_smem<nthreads, R>(smem + nthreads, tid);
+ if (tid == 0)
+ {
+ res.x = smem[0];
+ res.y = smem[nthreads];
+ result[0] = res;
+ blocks_finished = 0;
+ }
+ }
+ #else
+ if (tid == 0)
+ {
+ DstType res;
+ res.x = smem[0];
+ res.y = smem[nthreads];
+ result[bid] = res;
+ }
+ #endif
+ }
+ template <typename T, typename R, int nthreads>
+ __global__ void sum_pass2_kernel_C2(typename TypeVec<R, 2>::vec_t* result, int size)
+ {
+ typedef typename TypeVec<R, 2>::vec_t DstType;
+ __shared__ R smem[nthreads * 2];
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x;
+ DstType res = tid < size ? result[tid] : VecTraits<DstType>::all(0);
+ smem[tid] = res.x;
+ smem[tid + nthreads] = res.y;
+ __syncthreads();
+ sum_in_smem<nthreads, R>(smem, tid);
+ sum_in_smem<nthreads, R>(smem + nthreads, tid);
+ if (tid == 0)
+ {
+ res.x = smem[0];
+ res.y = smem[nthreads];
+ result[0] = res;
+ }
+ }
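Note the termination scheme in sum_kernel_C2: on sm_11 and newer, each block publishes its partial pair, __threadfence() makes the write globally visible, and atomicInc on blocks_finished hands out tickets; the block that draws the last ticket re-reduces all partials in place, so the single-pass caller needs no second launch. On older devices the #else branch only writes per-block partials, and the multi-pass caller follows up with sum_pass2_kernel_C2, launched as a single block that guards its loads with tid < size.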
} // namespace sum
template <typename T>
- void sum_multipass_caller(const DevMem2D src, PtrStep buf, double* sum)
+ void sum_multipass_caller(const DevMem2D src, PtrStep buf, double* sum, int cn)
{
using namespace sum;
typedef typename SumType<T>::R R;
@@ -1552,23 +1658,76 @@ namespace cv { namespace gpu { namespace mathfunc
estimate_thread_cfg(src.cols, src.rows, threads, grid);
set_kernel_consts(src.cols, src.rows, threads, grid);
- R* buf_ = (R*)buf.ptr(0);
- sum_kernel<T, R, IdentityOp<R>, threads_x * threads_y><<<grid, threads>>>((const DevMem2D_<T>)src, buf_);
- sum_pass2_kernel<T, R, threads_x * threads_y><<<1, threads_x * threads_y>>>(buf_, grid.x * grid.y);
+ switch (cn)
+ {
+ case 1:
+ sum_kernel<T, R, IdentityOp<R>, threads_x * threads_y><<<grid, threads>>>(
+ src, (typename TypeVec<R, 1>::vec_t*)buf.ptr(0));
+ sum_pass2_kernel<T, R, threads_x * threads_y><<<1, threads_x * threads_y>>>(
+ (typename TypeVec<R, 1>::vec_t*)buf.ptr(0), grid.x * grid.y);
+ break;
+ case 2:
+ sum_kernel_C2<T, R, IdentityOp<R>, threads_x * threads_y><<<grid, threads>>>(
+ src, (typename TypeVec<R, 2>::vec_t*)buf.ptr(0));
+ sum_pass2_kernel_C2<T, R, threads_x * threads_y><<<1, threads_x * threads_y>>>(
+ (typename TypeVec<R, 2>::vec_t*)buf.ptr(0), grid.x * grid.y);
+ break;
+ }
cudaSafeCall(cudaThreadSynchronize());
- R result = 0;
- cudaSafeCall(cudaMemcpy(&result, buf_, sizeof(result), cudaMemcpyDeviceToHost));
- sum[0] = result;
+ R result[4] = {0, 0, 0, 0};
+ cudaSafeCall(cudaMemcpy(&result, buf.ptr(0), sizeof(R) * cn, cudaMemcpyDeviceToHost));
+ sum[0] = result[0];
+ sum[1] = result[1];
+ sum[2] = result[2];
+ sum[3] = result[3];
}
- template void sum_multipass_caller<unsigned char>(const DevMem2D, PtrStep, double*);
- template void sum_multipass_caller<char>(const DevMem2D, PtrStep, double*);
- template void sum_multipass_caller<unsigned short>(const DevMem2D, PtrStep, double*);
- template void sum_multipass_caller<short>(const DevMem2D, PtrStep, double*);
- template void sum_multipass_caller<int>(const DevMem2D, PtrStep, double*);
- template void sum_multipass_caller<float>(const DevMem2D, PtrStep, double*);
+ template void sum_multipass_caller<unsigned char>(const DevMem2D, PtrStep, double*, int);
+ template void sum_multipass_caller<char>(const DevMem2D, PtrStep, double*, int);
+ template void sum_multipass_caller<unsigned short>(const DevMem2D, PtrStep, double*, int);
+ template void sum_multipass_caller<short>(const DevMem2D, PtrStep, double*, int);
+ template void sum_multipass_caller<int>(const DevMem2D, PtrStep, double*, int);
+ template void sum_multipass_caller<float>(const DevMem2D, PtrStep, double*, int);
+ template <typename T>
+ void sum_caller(const DevMem2D src, PtrStep buf, double* sum, int cn)
+ {
+ using namespace sum;
+ typedef typename SumType<T>::R R;
+ dim3 threads, grid;
+ estimate_thread_cfg(src.cols, src.rows, threads, grid);
+ set_kernel_consts(src.cols, src.rows, threads, grid);
+ switch (cn)
+ {
+ case 1:
+ sum_kernel<T, R, IdentityOp<R>, threads_x * threads_y><<<grid, threads>>>(
+ src, (typename TypeVec<R, 1>::vec_t*)buf.ptr(0));
+ break;
+ case 2:
+ sum_kernel_C2<T, R, IdentityOp<R>, threads_x * threads_y><<<grid, threads>>>(
+ src, (typename TypeVec<R, 2>::vec_t*)buf.ptr(0));
+ break;
+ }
+ cudaSafeCall(cudaThreadSynchronize());
+ R result[4] = {0, 0, 0, 0};
+ cudaSafeCall(cudaMemcpy(&result, buf.ptr(0), sizeof(R) * cn, cudaMemcpyDeviceToHost));
+ sum[0] = result[0];
+ sum[1] = result[1];
+ sum[2] = result[2];
+ sum[3] = result[3];
+ }
+ template void sum_caller<unsigned char>(const DevMem2D, PtrStep, double*, int);
+ template void sum_caller<char>(const DevMem2D, PtrStep, double*, int);
+ template void sum_caller<unsigned short>(const DevMem2D, PtrStep, double*, int);
+ template void sum_caller<short>(const DevMem2D, PtrStep, double*, int);
+ template void sum_caller<int>(const DevMem2D, PtrStep, double*, int);
+ template void sum_caller<float>(const DevMem2D, PtrStep, double*, int);
template <typename T>
@@ -1581,14 +1740,14 @@ namespace cv { namespace gpu { namespace mathfunc
estimate_thread_cfg(src.cols, src.rows, threads, grid);
set_kernel_consts(src.cols, src.rows, threads, grid);
- R* buf_ = (R*)buf.ptr(0);
- sum_kernel<T, R, SqrOp<R>, threads_x * threads_y><<<grid, threads>>>((const DevMem2D_<T>)src, buf_);
- sum_pass2_kernel<T, R, threads_x * threads_y><<<1, threads_x * threads_y>>>(buf_, grid.x * grid.y);
+ sum_kernel<T, R, SqrOp<R>, threads_x * threads_y><<<grid, threads>>>(
+ src, (typename TypeVec<R, 1>::vec_t*)buf.ptr(0));
+ sum_pass2_kernel<T, R, threads_x * threads_y><<<1, threads_x * threads_y>>>(
+ (typename TypeVec<R, 1>::vec_t*)buf.ptr(0), grid.x * grid.y);
cudaSafeCall(cudaThreadSynchronize());
R result = 0;
- cudaSafeCall(cudaMemcpy(&result, buf_, sizeof(result), cudaMemcpyDeviceToHost));
+ cudaSafeCall(cudaMemcpy(&result, buf.ptr(0), sizeof(R), cudaMemcpyDeviceToHost));
sum[0] = result;
}
@@ -1600,35 +1759,6 @@ namespace cv { namespace gpu { namespace mathfunc
template void sqsum_multipass_caller<float>(const DevMem2D, PtrStep, double*);
- template <typename T>
- void sum_caller(const DevMem2D src, PtrStep buf, double* sum)
- {
- using namespace sum;
- typedef typename SumType<T>::R R;
- dim3 threads, grid;
- estimate_thread_cfg(src.cols, src.rows, threads, grid);
- set_kernel_consts(src.cols, src.rows, threads, grid);
- R* buf_ = (R*)buf.ptr(0);
- sum_kernel<T, R, IdentityOp<R>, threads_x * threads_y><<<grid, threads>>>((const DevMem2D_<T>)src, buf_);
- cudaSafeCall(cudaThreadSynchronize());
- R result = 0;
- cudaSafeCall(cudaMemcpy(&result, buf_, sizeof(result), cudaMemcpyDeviceToHost));
- sum[0] = result;
- }
- template void sum_caller<unsigned char>(const DevMem2D, PtrStep, double*);
- template void sum_caller<char>(const DevMem2D, PtrStep, double*);
- template void sum_caller<unsigned short>(const DevMem2D, PtrStep, double*);
- template void sum_caller<short>(const DevMem2D, PtrStep, double*);
- template void sum_caller<int>(const DevMem2D, PtrStep, double*);
- template void sum_caller<float>(const DevMem2D, PtrStep, double*);
- template void sum_caller<double>(const DevMem2D, PtrStep, double*);
template <typename T>
void sqsum_caller(const DevMem2D src, PtrStep buf, double* sum)
{
@@ -1639,13 +1769,12 @@ namespace cv { namespace gpu { namespace mathfunc
estimate_thread_cfg(src.cols, src.rows, threads, grid);
set_kernel_consts(src.cols, src.rows, threads, grid);
- R* buf_ = (R*)buf.ptr(0);
- sum_kernel<T, R, SqrOp<R>, threads_x * threads_y><<<grid, threads>>>((const DevMem2D_<T>)src, buf_);
+ sum_kernel<T, R, SqrOp<R>, threads_x * threads_y><<<grid, threads>>>(
+ src, (typename TypeVec<R, 1>::vec_t*)buf.ptr(0));
cudaSafeCall(cudaThreadSynchronize());
R result = 0;
- cudaSafeCall(cudaMemcpy(&result, buf_, sizeof(result), cudaMemcpyDeviceToHost));
+ cudaSafeCall(cudaMemcpy(&result, buf.ptr(0), sizeof(R), cudaMemcpyDeviceToHost));
sum[0] = result;
}
@@ -1655,6 +1784,5 @@ namespace cv { namespace gpu { namespace mathfunc
template void sqsum_caller<short>(const DevMem2D, PtrStep, double*);
template void sqsum_caller<int>(const DevMem2D, PtrStep, double*);
template void sqsum_caller<float>(const DevMem2D, PtrStep, double*);
- template void sqsum_caller<double>(const DevMem2D, PtrStep, double*);
}}}


@@ -866,6 +866,91 @@ namespace cv
return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);
}
+ static __device__ uint1 operator+(const uint1& a, const uint1& b)
+ {
+ return make_uint1(a.x + b.x);
+ }
+ static __device__ uint1 operator-(const uint1& a, const uint1& b)
+ {
+ return make_uint1(a.x - b.x);
+ }
+ static __device__ uint1 operator*(const uint1& a, const uint1& b)
+ {
+ return make_uint1(a.x * b.x);
+ }
+ static __device__ uint1 operator/(const uint1& a, const uint1& b)
+ {
+ return make_uint1(a.x / b.x);
+ }
+ static __device__ float1 operator*(const uint1& a, float s)
+ {
+ return make_float1(a.x * s);
+ }
+ static __device__ uint2 operator+(const uint2& a, const uint2& b)
+ {
+ return make_uint2(a.x + b.x, a.y + b.y);
+ }
+ static __device__ uint2 operator-(const uint2& a, const uint2& b)
+ {
+ return make_uint2(a.x - b.x, a.y - b.y);
+ }
+ static __device__ uint2 operator*(const uint2& a, const uint2& b)
+ {
+ return make_uint2(a.x * b.x, a.y * b.y);
+ }
+ static __device__ uint2 operator/(const uint2& a, const uint2& b)
+ {
+ return make_uint2(a.x / b.x, a.y / b.y);
+ }
+ static __device__ float2 operator*(const uint2& a, float s)
+ {
+ return make_float2(a.x * s, a.y * s);
+ }
+ static __device__ uint3 operator+(const uint3& a, const uint3& b)
+ {
+ return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
+ }
+ static __device__ uint3 operator-(const uint3& a, const uint3& b)
+ {
+ return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
+ }
+ static __device__ uint3 operator*(const uint3& a, const uint3& b)
+ {
+ return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
+ }
+ static __device__ uint3 operator/(const uint3& a, const uint3& b)
+ {
+ return make_uint3(a.x / b.x, a.y / b.y, a.z / b.z);
+ }
+ static __device__ float3 operator*(const uint3& a, float s)
+ {
+ return make_float3(a.x * s, a.y * s, a.z * s);
+ }
+ static __device__ uint4 operator+(const uint4& a, const uint4& b)
+ {
+ return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
+ }
+ static __device__ uint4 operator-(const uint4& a, const uint4& b)
+ {
+ return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
+ }
+ static __device__ uint4 operator*(const uint4& a, const uint4& b)
+ {
+ return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+ }
+ static __device__ uint4 operator/(const uint4& a, const uint4& b)
+ {
+ return make_uint4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
+ }
+ static __device__ float4 operator*(const uint4& a, float s)
+ {
+ return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);
+ }
static __device__ float1 operator+(const float1& a, const float1& b)
{
return make_float1(a.x + b.x);

@@ -942,9 +942,18 @@ struct CV_GpuSumTest: CvTest
Scalar a, b;
double max_err = 1e-5;
- int typemax = hasNativeDoubleSupport(getDevice()) ? CV_64F : CV_32F;
+ int typemax = CV_32F;
for (int type = CV_8U; type <= typemax; ++type)
{
+ gen(1 + rand() % 500, 1 + rand() % 500, CV_MAKETYPE(type, 2), src);
+ a = sum(src);
+ b = sum(GpuMat(src));
+ if (abs(a[0] - b[0]) + abs(a[1] - b[1]) > src.size().area() * max_err)
+ {
+ ts->printf(CvTS::CONSOLE, "cols: %d, rows: %d, expected: %f, actual: %f\n", src.cols, src.rows, a[0], b[0]);
+ ts->set_failed_test_info(CvTS::FAIL_INVALID_OUTPUT);
+ return;
+ }
gen(1 + rand() % 500, 1 + rand() % 500, type, src);
a = sum(src);
b = sum(GpuMat(src));