added cv::gemm to T-API

This commit is contained in:
Ilya Lavrenov
2013-12-14 23:16:53 +04:00
parent 11071dd241
commit 2f34bb9aa0
4 changed files with 329 additions and 17 deletions

View File

@@ -41,6 +41,7 @@
//M*/
#include "precomp.hpp"
#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp"
#ifdef HAVE_IPP
#include "ippversion.h"
@@ -693,11 +694,102 @@ static void GEMMStore_64fc( const Complexd* c_data, size_t c_step,
GEMMStore(c_data, c_step, d_buf, d_buf_step, d_data, d_step, d_size, alpha, beta, flags);
}
#ifdef HAVE_CLAMDBLAS
static bool ocl_gemm( InputArray matA, InputArray matB, double alpha,
InputArray matC, double beta, OutputArray matD, int flags )
{
int type = matA.type(), esz = CV_ELEM_SIZE(type);
bool haveC = matC.kind() != cv::_InputArray::NONE;
Size sizeA = matA.size(), sizeB = matB.size(), sizeC = haveC ? matC.size() : Size(0, 0);
bool atrans = (flags & GEMM_1_T) != 0, btrans = (flags & GEMM_2_T) != 0, ctrans = (flags & GEMM_3_T) != 0;
if (atrans)
sizeA = Size(sizeA.height, sizeA.width);
if (btrans)
sizeB = Size(sizeB.height, sizeB.width);
if (haveC && ctrans)
sizeC = Size(sizeC.height, sizeC.width);
Size sizeD(sizeB.width, sizeA.height);
CV_Assert( matB.type() == type && (!haveC || matC.type() == type) );
CV_Assert( sizeA.width == sizeB.height && (!haveC || sizeC == sizeD) );
matD.create(sizeD, type);
if ( matA.offset() % esz != 0 || matA.step() % esz != 0 ||
matB.offset() % esz != 0 || matB.step() % esz != 0 ||
(haveC && (matC.offset() % esz != 0 || matC.step() % esz != 0)) )
return false;
UMat A = matA.getUMat(), B = matB.getUMat(), D = matD.getUMat();
if (haveC)
ctrans ? transpose(matC, D) : matC.getMat().copyTo(D); // TODO fix it as soon as .copyTo works as expected
else
D.setTo(Scalar::all(0));
int M = sizeD.height, N = sizeD.width, K = sizeA.width;
int lda = (int)A.step / esz, ldb = (int)B.step / esz, ldc = (int)D.step / esz;
int offa = (int)A.offset / esz, offb = (int)B.offset / esz, offc = (int)D.offset / esz;
cl_command_queue clq = (cl_command_queue)ocl::Queue::getDefault().ptr();
clAmdBlasTranspose transA = atrans ? clAmdBlasTrans : clAmdBlasNoTrans;
clAmdBlasTranspose transB = btrans ? clAmdBlasTrans : clAmdBlasNoTrans;
clAmdBlasOrder order = clAmdBlasRowMajor;
clAmdBlasStatus status = clAmdBlasSuccess;
if (type == CV_32FC1)
status = clAmdBlasSgemmEx(order, transA, transB, M, N, K,
(cl_float)alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
(const cl_mem)B.handle(ACCESS_READ), offb, ldb,
(cl_float)beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
1, &clq, 0, NULL, NULL);
else if (type == CV_64FC1)
status = clAmdBlasDgemmEx(order, transA, transB, M, N, K,
alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
(const cl_mem)B.handle(ACCESS_READ), offb, ldb,
beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
1, &clq, 0, NULL, NULL);
else if (type == CV_32FC2)
{
cl_float2 alpha_2 = { { (cl_float)alpha, 0 } };
cl_float2 beta_2 = { { (cl_float)beta, 0 } };
status = clAmdBlasCgemmEx(order, transA, transB, M, N, K,
alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
(const cl_mem)B.handle(ACCESS_READ), offb, ldb,
beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
1, &clq, 0, NULL, NULL);
}
else if (type == CV_64FC2)
{
cl_double2 alpha_2 = { { alpha, 0 } };
cl_double2 beta_2 = { { beta, 0 } };
status = clAmdBlasZgemmEx(order, transA, transB, M, N, K,
alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
(const cl_mem)B.handle(ACCESS_READ), offb, ldb,
beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
1, &clq, 0, NULL, NULL);
}
else
CV_Error(Error::StsUnsupportedFormat, "");
return status == clAmdBlasSuccess;
}
#endif
}
void cv::gemm( InputArray matA, InputArray matB, double alpha,
InputArray matC, double beta, OutputArray _matD, int flags )
{
#ifdef HAVE_CLAMDBLAS
if (ocl::haveAmdBlas() && matA.dims() <= 2 && matB.dims() <= 2 && matC.dims() <= 2 &&
ocl::useOpenCL() && _matD.isUMat() &&
ocl_gemm(matA, matB, alpha, matC, beta, _matD, flags))
return;
#endif
const int block_lin_size = 128;
const int block_size = block_lin_size * block_lin_size;