diff --git a/modules/core/src/dxt.cpp b/modules/core/src/dxt.cpp index c258bb28a..1cb59860c 100644 --- a/modules/core/src/dxt.cpp +++ b/modules/core/src/dxt.cpp @@ -1867,13 +1867,17 @@ public: UMat src = _src.getUMat(); UMat dst = _dst.getUMat(); + int type = src.type(), depth = CV_MAT_DEPTH(type); + size_t globalsize[2]; size_t localsize[2]; String kernel_name; bool is1d = (flags & DFT_ROWS) != 0 || num_dfts == 1; bool inv = (flags & DFT_INVERSE) != 0; - String options = buildOptions; + String options = buildOptions + format(" -D FT=%s -D CT=%s%s", /* FT = scalar type, CT = 2-channel type of the same depth; each macro needs its own -D, a bare "CT=..." token is not a macro definition */ ocl::typeToStr(depth), + ocl::typeToStr(CV_MAKE_TYPE(depth, 2)), + depth == CV_64F ? " -D DOUBLE_SUPPORT" : ""); if (rows) { @@ -2039,9 +2043,11 @@ static bool ocl_dft_cols(InputArray _src, OutputArray _dst, int nonzero_cols, in static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_rows) { - int type = _src.type(), cn = CV_MAT_CN(type); + int type = _src.type(), cn = CV_MAT_CN(type), depth = CV_MAT_DEPTH(type); Size ssize = _src.size(); - if ( !(type == CV_32FC1 || type == CV_32FC2) ) + bool doubleSupport = ocl::Device::getDefault().doubleFPConfig(); + + if ( !((cn == 1 || cn == 2) && (depth == CV_32F || (depth == CV_64F && doubleSupport))) ) return false; // if is not a multiplication of prime numbers { 2, 3, 5 } @@ -2082,7 +2088,7 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro if (fftType == C2C || fftType == R2C) { // complex output - _dst.create(src.size(), CV_32FC2); + _dst.create(src.size(), CV_MAKETYPE(depth, 2)); output = _dst.getUMat(); } else @@ -2090,13 +2096,13 @@ static bool ocl_dft(InputArray _src, OutputArray _dst, int flags, int nonzero_ro // real output if (is1d) { - _dst.create(src.size(), CV_32FC1); + _dst.create(src.size(), CV_MAKETYPE(depth, 1)); output = _dst.getUMat(); } else { - _dst.create(src.size(), CV_32FC1); - output.create(src.size(), CV_32FC2); + _dst.create(src.size(), CV_MAKETYPE(depth, 1)); + output.create(src.size(), 
CV_MAKETYPE(depth, 2)); } } diff --git a/modules/core/src/opencl/fft.cl b/modules/core/src/opencl/fft.cl index b56f5c1dc..afc968496 100644 --- a/modules/core/src/opencl/fft.cl +++ b/modules/core/src/opencl/fft.cl @@ -12,6 +12,14 @@ #define fft5_4 -1.538841768587f #define fft5_5 0.363271264002f +#ifdef DOUBLE_SUPPORT +#ifdef cl_amd_fp64 +#pragma OPENCL EXTENSION cl_amd_fp64:enable +#elif defined (cl_khr_fp64) +#pragma OPENCL EXTENSION cl_khr_fp64:enable +#endif +#endif + __attribute__((always_inline)) float2 mul_float2(float2 a, float2 b) { return (float2)(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); @@ -530,25 +538,25 @@ __kernel void fft_multi_radix_rows(__global const uchar* src_ptr, int src_step, const int block_size = LOCAL_SIZE/kercn; if (y < nz) { - __local float2 smem[LOCAL_SIZE]; + __local CT smem[LOCAL_SIZE]; __global const float2* twiddles = (__global float2*) twiddles_ptr; const int ind = x; #ifdef IS_1D - float scale = 1.f/dst_cols; + FT scale = (FT) 1/dst_cols; #else - float scale = 1.f/(dst_cols*dst_rows); + FT scale = (FT) 1/(dst_cols*dst_rows); #endif #ifdef COMPLEX_INPUT - __global const float2* src = (__global const float2*)(src_ptr + mad24(y, src_step, mad24(x, (int)(sizeof(float)*2), src_offset))); + __global const CT* src = (__global const CT*)(src_ptr + mad24(y, src_step, mad24(x, (int)(sizeof(CT)), src_offset))); #pragma unroll for (int i=0; i