implemented gpu::gemm via CUBLAS
This commit is contained in:
@@ -48,6 +48,7 @@ using namespace std;
|
||||
|
||||
#if !defined (HAVE_CUDA)
|
||||
|
||||
void cv::gpu::gemm(const GpuMat&, const GpuMat&, double, const GpuMat&, double, GpuMat&, int, Stream&) { throw_nogpu(); }
|
||||
void cv::gpu::transpose(const GpuMat&, GpuMat&, Stream&) { throw_nogpu(); }
|
||||
void cv::gpu::flip(const GpuMat&, GpuMat&, int, Stream&) { throw_nogpu(); }
|
||||
void cv::gpu::LUT(const GpuMat&, const Mat&, GpuMat&, Stream&) { throw_nogpu(); }
|
||||
@@ -63,6 +64,133 @@ void cv::gpu::polarToCart(const GpuMat&, const GpuMat&, GpuMat&, GpuMat&, bool,
|
||||
|
||||
#else /* !defined (HAVE_CUDA) */
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// gemm
|
||||
|
||||
void cv::gpu::gemm(const GpuMat& src1, const GpuMat& src2, double alpha, const GpuMat& src3, double beta, GpuMat& dst, int flags, Stream& stream)
|
||||
{
|
||||
#ifndef HAVE_CUBLAS
|
||||
|
||||
OPENCV_GPU_UNUSED(src1);
|
||||
OPENCV_GPU_UNUSED(src2);
|
||||
OPENCV_GPU_UNUSED(alpha);
|
||||
OPENCV_GPU_UNUSED(src3);
|
||||
OPENCV_GPU_UNUSED(beta);
|
||||
OPENCV_GPU_UNUSED(dst);
|
||||
OPENCV_GPU_UNUSED(flags);
|
||||
OPENCV_GPU_UNUSED(stream);
|
||||
|
||||
throw_nogpu();
|
||||
|
||||
#else
|
||||
|
||||
// CUBLAS works with column-major matrices
|
||||
|
||||
CV_Assert(src1.type() == CV_32FC1 || src1.type() == CV_32FC2 || src1.type() == CV_64FC1 || src1.type() == CV_64FC2);
|
||||
CV_Assert(src2.type() == src1.type() && (src3.empty() || src3.type() == src1.type()));
|
||||
|
||||
bool tr1 = flags & GEMM_1_T;
|
||||
bool tr2 = flags & GEMM_2_T;
|
||||
bool tr3 = flags & GEMM_3_T;
|
||||
|
||||
Size src1Size = tr1 ? Size(src1.rows, src1.cols) : src1.size();
|
||||
Size src2Size = tr2 ? Size(src2.rows, src2.cols) : src2.size();
|
||||
Size src3Size = tr3 ? Size(src3.rows, src3.cols) : src3.size();
|
||||
Size dstSize(src2Size.width, src1Size.height);
|
||||
|
||||
CV_Assert(src1Size.width == src2Size.height);
|
||||
CV_Assert(src3.empty() || src3Size == dstSize);
|
||||
|
||||
dst.create(dstSize, CV_32FC1);
|
||||
|
||||
if (beta != 0)
|
||||
{
|
||||
if (src3.empty())
|
||||
{
|
||||
if (stream)
|
||||
stream.enqueueMemSet(dst, Scalar::all(0));
|
||||
else
|
||||
dst.setTo(Scalar::all(0));
|
||||
}
|
||||
else
|
||||
{
|
||||
if (tr3)
|
||||
{
|
||||
transpose(src3, dst, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (stream)
|
||||
stream.enqueueCopy(src3, dst);
|
||||
else
|
||||
src3.copyTo(dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cublasHandle_t handle;
|
||||
cublasSafeCall( cublasCreate_v2(&handle) );
|
||||
|
||||
cublasSafeCall( cublasSetStream_v2(handle, StreamAccessor::getStream(stream)) );
|
||||
|
||||
cublasSafeCall( cublasSetPointerMode_v2(handle, CUBLAS_POINTER_MODE_HOST) );
|
||||
|
||||
const float alphaf = static_cast<float>(alpha);
|
||||
const float betaf = static_cast<float>(beta);
|
||||
|
||||
const cuComplex alphacf = make_cuComplex(alphaf, 0);
|
||||
const cuComplex betacf = make_cuComplex(betaf, 0);
|
||||
|
||||
const cuDoubleComplex alphac = make_cuDoubleComplex(alpha, 0);
|
||||
const cuDoubleComplex betac = make_cuDoubleComplex(beta, 0);
|
||||
|
||||
cublasOperation_t transa = tr2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
cublasOperation_t transb = tr1 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
|
||||
switch (src1.type())
|
||||
{
|
||||
case CV_32FC1:
|
||||
cublasSafeCall( cublasSgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
|
||||
&alphaf,
|
||||
src2.ptr<float>(), static_cast<int>(src2.step / sizeof(float)),
|
||||
src1.ptr<float>(), static_cast<int>(src1.step / sizeof(float)),
|
||||
&betaf,
|
||||
dst.ptr<float>(), static_cast<int>(dst.step / sizeof(float))) );
|
||||
break;
|
||||
|
||||
case CV_64FC1:
|
||||
cublasSafeCall( cublasDgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
|
||||
&alpha,
|
||||
src2.ptr<double>(), static_cast<int>(src2.step / sizeof(double)),
|
||||
src1.ptr<double>(), static_cast<int>(src1.step / sizeof(double)),
|
||||
&beta,
|
||||
dst.ptr<double>(), static_cast<int>(dst.step / sizeof(double))) );
|
||||
break;
|
||||
|
||||
case CV_32FC2:
|
||||
cublasSafeCall( cublasCgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
|
||||
&alphacf,
|
||||
src2.ptr<cuComplex>(), static_cast<int>(src2.step / sizeof(cuComplex)),
|
||||
src1.ptr<cuComplex>(), static_cast<int>(src1.step / sizeof(cuComplex)),
|
||||
&betacf,
|
||||
dst.ptr<cuComplex>(), static_cast<int>(dst.step / sizeof(cuComplex))) );
|
||||
break;
|
||||
|
||||
case CV_64FC2:
|
||||
cublasSafeCall( cublasZgemm_v2(handle, transa, transb, tr2 ? src2.rows : src2.cols, tr1 ? src1.cols : src1.rows, tr2 ? src2.cols : src2.rows,
|
||||
&alphac,
|
||||
src2.ptr<cuDoubleComplex>(), static_cast<int>(src2.step / sizeof(cuDoubleComplex)),
|
||||
src1.ptr<cuDoubleComplex>(), static_cast<int>(src1.step / sizeof(cuDoubleComplex)),
|
||||
&betac,
|
||||
dst.ptr<cuDoubleComplex>(), static_cast<int>(dst.step / sizeof(cuDoubleComplex))) );
|
||||
break;
|
||||
}
|
||||
|
||||
cublasSafeCall( cublasDestroy_v2(handle) );
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// transpose
|
||||
|
||||
|
@@ -434,7 +434,7 @@ void cv::gpu::multiply(const GpuMat& src, const Scalar& sc, GpuMat& dst, double
|
||||
{0/*multiply_gpu<double, unsigned char>*/, 0/*multiply_gpu<double, signed char>*/, 0/*multiply_gpu<double, unsigned short>*/, 0/*multiply_gpu<double, short>*/, 0/*multiply_gpu<double, int>*/, 0/*multiply_gpu<double, float>*/, multiply_gpu<double, double>}
|
||||
};
|
||||
|
||||
CV_Assert(src.channels() == 1);
|
||||
//CV_Assert(src.channels() == 1);
|
||||
|
||||
if (dtype < 0)
|
||||
dtype = src.depth();
|
||||
@@ -463,7 +463,7 @@ void cv::gpu::multiply(const GpuMat& src, const Scalar& sc, GpuMat& dst, double
|
||||
const func_t func = funcs[src.depth()][dst.depth()];
|
||||
CV_Assert(func != 0);
|
||||
|
||||
func(src, sc.val[0], dst, scale, stream);
|
||||
func(src.reshape(1), sc.val[0], dst.reshape(1), scale, stream);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
Reference in New Issue
Block a user