implemented gpu::gemm via CUBLAS

This commit is contained in:
Vladislav Vinogradov
2011-10-19 13:29:54 +00:00
parent 90ff3dd990
commit e7502e7641
6 changed files with 259 additions and 2 deletions

View File

@@ -48,6 +48,7 @@ using namespace std;
#if !defined (HAVE_CUDA)
void cv::gpu::gemm(const GpuMat&, const GpuMat&, double, const GpuMat&, double, GpuMat&, int, Stream&) { throw_nogpu(); }
void cv::gpu::transpose(const GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
void cv::gpu::flip(const GpuMat&, GpuMat&, int, Stream&) { throw_nogpu(); }
void cv::gpu::LUT(const GpuMat&, const Mat&, GpuMat&, Stream&) { throw_nogpu(); }
@@ -63,6 +64,133 @@ void cv::gpu::polarToCart(const GpuMat&, const GpuMat&, GpuMat&, GpuMat&, bool,
#else /* !defined (HAVE_CUDA) */
////////////////////////////////////////////////////////////////////////
// gemm
void cv::gpu::gemm(const GpuMat& src1, const GpuMat& src2, double alpha, const GpuMat& src3, double beta, GpuMat& dst, int flags, Stream& stream)
{
#ifndef HAVE_CUBLAS
OPENCV_GPU_UNUSED(src1);
OPENCV_GPU_UNUSED(src2);
OPENCV_GPU_UNUSED(alpha);
OPENCV_GPU_UNUSED(src3);
OPENCV_GPU_UNUSED(beta);
OPENCV_GPU_UNUSED(dst);
OPENCV_GPU_UNUSED(flags);
OPENCV_GPU_UNUSED(stream);
throw_nogpu();
#else
// CUBLAS works with column-major matrices
CV_Assert(src1.type() == CV_32FC1 || src1.type() == CV_32FC2 || src1.type() == CV_64FC1 || src1.type() == CV_64FC2);
CV_Assert(src2.type() == src1.type() && (src3.empty() || src3.type() == src1.type()));
bool tr1 = flags & GEMM_1_T;
bool tr2 = flags & GEMM_2_T;
bool tr3 = flags & GEMM_3_T;
Size src1Size = tr1 ? Size(src1.rows, src1.cols) : src1.size();
Size src2Size = tr2 ? Size(src2.rows, src2.cols) : src2.size();
Size src3Size = tr3 ? Size(src3.rows, src3.cols) : src3.size();
Size dstSize(src2Size.width, src1Size.height);
CV_Assert(src1Size.width == src2Size.height);
CV_Assert(src3.empty() || src3Size == dstSize);
dst.create(dstSize, CV_32FC1);
if (beta != 0)
{
if (src3.empty())
{
if (stream)
stream.enqueueMemSet(dst, Scalar::all(0));
else
dst.setTo(Scalar::all(0));
}
else
{
if (tr3)
{
transpose(src3, dst, stream);
}
else
{
if (stream)
stream.enqueueCopy(src3, dst);
else
src3.copyTo(dst);
}
}
}
cublasHandle_t handle;
cublasSafeCall( cublasCreate_v2(&handle) );
cublasSafeCall( cublasSetStream_v2(handle, StreamAccessor::getStream(stream)) );
cublasSafeCall( cublasSetPointerMode_v2(handle, CUBLAS_POINTER_MODE_HOST) );
const float alphaf = static_cast<float>(alpha);
const float betaf = static_cast<float>(beta);
const cuComplex alphacf = make_cuComplex(alphaf, 0);
const cuComplex betacf = make_cuComplex(betaf, 0);
const cuDoubleComplex alphac = make_cuDoubleComplex(alpha, 0);
const cuDoubleComplex betac = make_cuDoubleComplex(beta, 0);
cublasOperation_t transa = tr2 ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transb = tr1 ? CUBLAS_OP_T : CUBLAS_OP_N;
switch (src1.type())
{
case CV_32FC1:
cublasSafeCall( cublasSgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
&alphaf,
src2.ptr<float>(), static_cast<int>(src2.step / sizeof(float)),
src1.ptr<float>(), static_cast<int>(src1.step / sizeof(float)),
&betaf,
dst.ptr<float>(), static_cast<int>(dst.step / sizeof(float))) );
break;
case CV_64FC1:
cublasSafeCall( cublasDgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
&alpha,
src2.ptr<double>(), static_cast<int>(src2.step / sizeof(double)),
src1.ptr<double>(), static_cast<int>(src1.step / sizeof(double)),
&beta,
dst.ptr<double>(), static_cast<int>(dst.step / sizeof(double))) );
break;
case CV_32FC2:
cublasSafeCall( cublasCgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
&alphacf,
src2.ptr<cuComplex>(), static_cast<int>(src2.step / sizeof(cuComplex)),
src1.ptr<cuComplex>(), static_cast<int>(src1.step / sizeof(cuComplex)),
&betacf,
dst.ptr<cuComplex>(), static_cast<int>(dst.step / sizeof(cuComplex))) );
break;
case CV_64FC2:
cublasSafeCall( cublasZgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
&alphac,
src2.ptr<cuDoubleComplex>(), static_cast<int>(src2.step / sizeof(cuDoubleComplex)),
src1.ptr<cuDoubleComplex>(), static_cast<int>(src1.step / sizeof(cuDoubleComplex)),
&betac,
dst.ptr<cuDoubleComplex>(), static_cast<int>(dst.step / sizeof(cuDoubleComplex))) );
break;
}
cublasSafeCall( cublasDestroy_v2(handle) );
#endif
}
////////////////////////////////////////////////////////////////////////
// transpose

View File

@@ -434,7 +434,7 @@ void cv::gpu::multiply(const GpuMat& src, const Scalar& sc, GpuMat& dst, double
{0/*multiply_gpu<double, unsigned char>*/, 0/*multiply_gpu<double, signed char>*/, 0/*multiply_gpu<double, unsigned short>*/, 0/*multiply_gpu<double, short>*/, 0/*multiply_gpu<double, int>*/, 0/*multiply_gpu<double, float>*/, multiply_gpu<double, double>}
};
CV_Assert(src.channels() == 1);
//CV_Assert(src.channels() == 1);
if (dtype < 0)
dtype = src.depth();
@@ -463,7 +463,7 @@ void cv::gpu::multiply(const GpuMat& src, const Scalar& sc, GpuMat& dst, double
const func_t func = funcs[src.depth()][dst.depth()];
CV_Assert(func != 0);
func(src, sc.val[0], dst, scale, stream);
func(src.reshape(1), sc.val[0], dst.reshape(1), scale, stream);
}
////////////////////////////////////////////////////////////////////////