added cv::gemm to T-API
This commit is contained in:
@@ -41,6 +41,7 @@
|
||||
//M*/
|
||||
|
||||
#include "precomp.hpp"
|
||||
#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp"
|
||||
|
||||
#ifdef HAVE_IPP
|
||||
#include "ippversion.h"
|
||||
@@ -693,11 +694,102 @@ static void GEMMStore_64fc( const Complexd* c_data, size_t c_step,
|
||||
GEMMStore(c_data, c_step, d_buf, d_buf_step, d_data, d_step, d_size, alpha, beta, flags);
|
||||
}
|
||||
|
||||
#ifdef HAVE_CLAMDBLAS
|
||||
|
||||
// OpenCL (clAmdBlas) path of cv::gemm.
//
// Validates the operand types and shapes, allocates matD, seeds it with C
// (or C transposed, or zeros when no C is given) and then enqueues the
// matching clAmdBlas *gemmEx routine on the default OpenCL queue with D as
// the in/out matrix.
//
// Returns true when the clAmdBlas call reports success. Returns false when
// an offset/step is not a multiple of the element size or when clAmdBlas
// reports an error; the caller then continues with its non-OpenCL code.
// Raises StsUnsupportedFormat for types other than 32F/64F with 1 or 2
// channels.
static bool ocl_gemm( InputArray matA, InputArray matB, double alpha,
                      InputArray matC, double beta, OutputArray matD, int flags )
{
    const int type = matA.type();
    const int esz = CV_ELEM_SIZE(type);
    const bool haveC = matC.kind() != cv::_InputArray::NONE;

    const bool atrans = (flags & GEMM_1_T) != 0;
    const bool btrans = (flags & GEMM_2_T) != 0;
    const bool ctrans = (flags & GEMM_3_T) != 0;

    // Logical (post-transposition) sizes of the operands.
    Size sizeA = matA.size();
    Size sizeB = matB.size();
    Size sizeC = haveC ? matC.size() : Size(0, 0);

    if (atrans)
        sizeA = Size(sizeA.height, sizeA.width);
    if (btrans)
        sizeB = Size(sizeB.height, sizeB.width);
    if (haveC && ctrans)
        sizeC = Size(sizeC.height, sizeC.width);

    const Size sizeD(sizeB.width, sizeA.height);

    CV_Assert( matB.type() == type && (!haveC || matC.type() == type) );
    CV_Assert( sizeA.width == sizeB.height && (!haveC || sizeC == sizeD) );

    // Note: the destination is allocated before the alignment bail-out below.
    matD.create(sizeD, type);

    // Offsets and steps are handed to clAmdBlas in elements, so every one of
    // them has to be a whole number of elements.
    bool aligned = matA.offset() % esz == 0 && matA.step() % esz == 0 &&
                   matB.offset() % esz == 0 && matB.step() % esz == 0;
    if (aligned && haveC)
        aligned = matC.offset() % esz == 0 && matC.step() % esz == 0;
    if (!aligned)
        return false;

    UMat A = matA.getUMat(), B = matB.getUMat(), D = matD.getUMat();

    // D doubles as the "C" operand of the BLAS call, so preload it.
    if (!haveC)
        D.setTo(Scalar::all(0));
    else if (ctrans)
        transpose(matC, D);
    else
        matC.getMat().copyTo(D); // TODO fix it as soon as .copyTo works as expected

    const int M = sizeD.height;
    const int N = sizeD.width;
    const int K = sizeA.width;

    // Leading dimensions and offsets, expressed in elements.
    const int lda = (int)A.step / esz;
    const int ldb = (int)B.step / esz;
    const int ldc = (int)D.step / esz;

    const int offa = (int)A.offset / esz;
    const int offb = (int)B.offset / esz;
    const int offc = (int)D.offset / esz;

    cl_command_queue clq = (cl_command_queue)ocl::Queue::getDefault().ptr();
    const clAmdBlasTranspose transA = atrans ? clAmdBlasTrans : clAmdBlasNoTrans;
    const clAmdBlasTranspose transB = btrans ? clAmdBlasTrans : clAmdBlasNoTrans;
    const clAmdBlasOrder order = clAmdBlasRowMajor;
    clAmdBlasStatus status = clAmdBlasSuccess;

    switch (type)
    {
    case CV_32FC1:
        status = clAmdBlasSgemmEx(order, transA, transB, M, N, K,
                                  static_cast<cl_float>(alpha), (const cl_mem)A.handle(ACCESS_READ), offa, lda,
                                  (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
                                  static_cast<cl_float>(beta), (cl_mem)D.handle(ACCESS_RW), offc, ldc,
                                  1, &clq, 0, NULL, NULL);
        break;

    case CV_64FC1:
        status = clAmdBlasDgemmEx(order, transA, transB, M, N, K,
                                  alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
                                  (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
                                  beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
                                  1, &clq, 0, NULL, NULL);
        break;

    case CV_32FC2:
        {
            // Complex scalars with a zero imaginary part.
            cl_float2 alpha_2 = { { static_cast<cl_float>(alpha), 0 } };
            cl_float2 beta_2  = { { static_cast<cl_float>(beta), 0 } };
            status = clAmdBlasCgemmEx(order, transA, transB, M, N, K,
                                      alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
                                      (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
                                      beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
                                      1, &clq, 0, NULL, NULL);
        }
        break;

    case CV_64FC2:
        {
            // Complex scalars with a zero imaginary part.
            cl_double2 alpha_2 = { { alpha, 0 } };
            cl_double2 beta_2  = { { beta, 0 } };
            status = clAmdBlasZgemmEx(order, transA, transB, M, N, K,
                                      alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
                                      (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
                                      beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
                                      1, &clq, 0, NULL, NULL);
        }
        break;

    default:
        CV_Error(Error::StsUnsupportedFormat, "");
        break;
    }

    return status == clAmdBlasSuccess;
}
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void cv::gemm( InputArray matA, InputArray matB, double alpha,
|
||||
InputArray matC, double beta, OutputArray _matD, int flags )
|
||||
{
|
||||
#ifdef HAVE_CLAMDBLAS
|
||||
if (ocl::haveAmdBlas() && matA.dims() <= 2 && matB.dims() <= 2 && matC.dims() <= 2 &&
|
||||
ocl::useOpenCL() && _matD.isUMat() &&
|
||||
ocl_gemm(matA, matB, alpha, matC, beta, _matD, flags))
|
||||
return;
|
||||
#endif
|
||||
|
||||
const int block_lin_size = 128;
|
||||
const int block_size = block_lin_size * block_lin_size;
|
||||
|
||||
|
@@ -42,6 +42,8 @@
|
||||
#include "precomp.hpp"
|
||||
#include <map>
|
||||
|
||||
#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp"
|
||||
|
||||
#ifdef HAVE_OPENCL
|
||||
#include "opencv2/core/opencl/runtime/opencl_core.hpp"
|
||||
#else
|
||||
@@ -1309,29 +1311,23 @@ inline bool operator < (const HashKey& h1, const HashKey& h2)
|
||||
return h1.a < h2.a || (h1.a == h2.a && h1.b < h2.b);
|
||||
}
|
||||
|
||||
static bool g_isInitialized = false;
|
||||
static bool g_isOpenCLInitialized = false;
|
||||
static bool g_isOpenCLAvailable = false;
|
||||
|
||||
bool haveOpenCL()
|
||||
{
|
||||
if (!g_isInitialized)
|
||||
if (!g_isOpenCLInitialized)
|
||||
{
|
||||
if (!g_isInitialized)
|
||||
try
|
||||
{
|
||||
try
|
||||
{
|
||||
cl_uint n = 0;
|
||||
cl_int err = ::clGetPlatformIDs(0, NULL, &n);
|
||||
if (err != CL_SUCCESS)
|
||||
g_isOpenCLAvailable = false;
|
||||
else
|
||||
g_isOpenCLAvailable = true;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
g_isOpenCLAvailable = false;
|
||||
}
|
||||
g_isInitialized = true;
|
||||
cl_uint n = 0;
|
||||
g_isOpenCLAvailable = ::clGetPlatformIDs(0, NULL, &n) == CL_SUCCESS;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
g_isOpenCLAvailable = false;
|
||||
}
|
||||
g_isOpenCLInitialized = true;
|
||||
}
|
||||
return g_isOpenCLAvailable;
|
||||
}
|
||||
@@ -1353,6 +1349,80 @@ void setUseOpenCL(bool flag)
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef HAVE_CLAMDBLAS
|
||||
|
||||
class AmdBlasHelper
|
||||
{
|
||||
public:
|
||||
static AmdBlasHelper & getInstance()
|
||||
{
|
||||
static AmdBlasHelper amdBlas;
|
||||
return amdBlas;
|
||||
}
|
||||
|
||||
bool isAvailable() const
|
||||
{
|
||||
return g_isAmdBlasAvailable;
|
||||
}
|
||||
|
||||
~AmdBlasHelper()
|
||||
{
|
||||
try
|
||||
{
|
||||
clAmdBlasTeardown();
|
||||
}
|
||||
catch (...) { }
|
||||
}
|
||||
|
||||
protected:
|
||||
AmdBlasHelper()
|
||||
{
|
||||
if (!g_isAmdBlasInitialized)
|
||||
{
|
||||
AutoLock lock(m);
|
||||
|
||||
if (!g_isAmdBlasInitialized && haveOpenCL())
|
||||
{
|
||||
try
|
||||
{
|
||||
g_isAmdBlasAvailable = clAmdBlasSetup() == clAmdBlasSuccess;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
g_isAmdBlasAvailable = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
g_isAmdBlasAvailable = false;
|
||||
|
||||
g_isAmdBlasInitialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static Mutex m;
|
||||
static bool g_isAmdBlasInitialized;
|
||||
static bool g_isAmdBlasAvailable;
|
||||
};
|
||||
|
||||
bool AmdBlasHelper::g_isAmdBlasAvailable = false;
|
||||
bool AmdBlasHelper::g_isAmdBlasInitialized = false;
|
||||
Mutex AmdBlasHelper::m;
|
||||
|
||||
bool haveAmdBlas()
|
||||
{
|
||||
return AmdBlasHelper::getInstance().isAvailable();
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Fallback used when the library is built without HAVE_CLAMDBLAS:
// clAmdBlas is never available.
bool haveAmdBlas()
{
    const bool amdBlasCompiledIn = false;
    return amdBlasCompiledIn;
}
|
||||
|
||||
#endif
|
||||
|
||||
void finish2()
|
||||
{
|
||||
Queue::getDefault().finish();
|
||||
|
Reference in New Issue
Block a user