refactored gpu::dft

This commit is contained in:
Alexey Spizhevoy
2010-12-27 07:35:41 +00:00
parent a379d011fd
commit 8f0d36b8b6
3 changed files with 28 additions and 45 deletions

View File

@@ -76,7 +76,7 @@ void cv::gpu::cornerHarris(const GpuMat&, GpuMat&, int, int, double, int) { thro
// Stub: leaves every argument unnamed and unused; the body only calls throw_nogpu().
void cv::gpu::cornerMinEigenVal(const GpuMat&, GpuMat&, int, int, int)
{
    throw_nogpu();
}
// Stub: leaves every argument unnamed and unused; the body only calls throw_nogpu().
void cv::gpu::mulSpectrums(const GpuMat&, const GpuMat&, GpuMat&, int, bool)
{
    throw_nogpu();
}
// Stub: leaves every argument unnamed and unused; the body only calls throw_nogpu().
void cv::gpu::mulAndScaleSpectrums(const GpuMat&, const GpuMat&, GpuMat&, int, float, bool)
{
    throw_nogpu();
}
// Stub for the (src, dst, int, int, bool) overload: arguments are unnamed and
// unused; the body only calls throw_nogpu().
void cv::gpu::dft(const GpuMat&, GpuMat&, int, int, bool)
{
    throw_nogpu();
}
// Stub for the (src, dst, Size, int) overload: arguments are unnamed and
// unused; the body only calls throw_nogpu().
void cv::gpu::dft(const GpuMat&, GpuMat&, Size, int)
{
    throw_nogpu();
}
// Stub: leaves every argument unnamed and unused; the body only calls throw_nogpu().
void cv::gpu::convolve(const GpuMat&, const GpuMat&, GpuMat&, bool)
{
    throw_nogpu();
}
@@ -1130,14 +1130,14 @@ void cv::gpu::mulAndScaleSpectrums(const GpuMat& a, const GpuMat& b, GpuMat& c,
//////////////////////////////////////////////////////////////////////////////
// dft
void cv::gpu::dft(const GpuMat& src, GpuMat& dst, int flags, int nonZeroRows, bool odd)
void cv::gpu::dft(const GpuMat& src, GpuMat& dst, Size dft_size, int flags)
{
CV_Assert(src.type() == CV_32F || src.type() == CV_32FC2);
// We don't support unpacked output (in the case of real input)
CV_Assert(!(flags & DFT_COMPLEX_OUTPUT));
bool is_1d_input = (src.rows == 1) || (src.cols == 1);
bool is_1d_input = (dft_size.height == 1) || (dft_size.width == 1);
int is_row_dft = flags & DFT_ROWS;
int is_scaled_dft = flags & DFT_SCALE;
int is_inverse = flags & DFT_INVERSE;
@@ -1156,63 +1156,49 @@ void cv::gpu::dft(const GpuMat& src, GpuMat& dst, int flags, int nonZeroRows, bo
if (src_data.data != src.data)
src.copyTo(src_data);
Size dft_size_ = dft_size;
if (is_1d_input && !is_row_dft)
// If the source matrix is a single column, reshape it into a single row
src_data = src_data.reshape(0, std::min(src.rows, src.cols));
{
// If the source matrix is a single column, handle it as a single row
dft_size_.width = std::max(dft_size.width, dft_size.height);
dft_size_.height = std::min(dft_size.width, dft_size.height);
}
cufftType dft_type = CUFFT_R2C;
if (is_complex_input)
dft_type = is_complex_output ? CUFFT_C2C : CUFFT_C2R;
int dft_rows = src_data.rows;
int dft_cols = src_data.cols;
if (is_complex_input && !is_complex_output)
dft_cols = (src_data.cols - 1) * 2 + (int)odd;
CV_Assert(dft_cols > 1);
CV_Assert(dft_size_.width > 1);
cufftHandle plan;
if (is_1d_input || is_row_dft)
cufftPlan1d(&plan, dft_cols, dft_type, dft_rows);
cufftPlan1d(&plan, dft_size_.width, dft_type, dft_size_.height);
else
cufftPlan2d(&plan, dft_rows, dft_cols, dft_type);
int dst_cols, dst_rows;
cufftPlan2d(&plan, dft_size_.height, dft_size_.width, dft_type);
if (is_complex_input)
{
if (is_complex_output)
{
createContinuous(src.rows, src.cols, CV_32FC2, dst);
createContinuous(dft_size, CV_32FC2, dst);
cufftSafeCall(cufftExecC2C(
plan, src_data.ptr<cufftComplex>(), dst.ptr<cufftComplex>(),
is_inverse ? CUFFT_INVERSE : CUFFT_FORWARD));
}
else
{
dst_rows = src.rows;
dst_cols = (src.cols - 1) * 2 + (int)odd;
if (src_data.size() != src.size())
{
dst_rows = (src.rows - 1) * 2 + (int)odd;
dst_cols = src.cols;
}
createContinuous(dst_rows, dst_cols, CV_32F, dst);
createContinuous(dft_size, CV_32F, dst);
cufftSafeCall(cufftExecC2R(
plan, src_data.ptr<cufftComplex>(), dst.ptr<cufftReal>()));
}
}
else
{
dst_rows = src.rows;
dst_cols = src.cols / 2 + 1;
if (src_data.size() != src.size())
{
dst_rows = src.rows / 2 + 1;
dst_cols = src.cols;
}
if (dft_size == dft_size_)
createContinuous(Size(dft_size.width / 2 + 1, dft_size.height), CV_32FC2, dst);
else
createContinuous(Size(dft_size.width, dft_size.height / 2 + 1), CV_32FC2, dst);
createContinuous(dst_rows, dst_cols, CV_32FC2, dst);
cufftSafeCall(cufftExecR2C(
plan, src_data.ptr<cufftReal>(), dst.ptr<cufftComplex>()));
}
@@ -1220,7 +1206,7 @@ void cv::gpu::dft(const GpuMat& src, GpuMat& dst, int flags, int nonZeroRows, bo
cufftSafeCall(cufftDestroy(plan));
if (is_scaled_dft)
multiply(dst, Scalar::all(1. / (dft_rows * dft_cols)), dst);
multiply(dst, Scalar::all(1. / (dft_size.area())), dst);
}
//////////////////////////////////////////////////////////////////////////////