used new device layer for cv::gpu::multiply
This commit is contained in:
		| @@ -40,172 +40,185 @@ | ||||
| // | ||||
| //M*/ | ||||
|  | ||||
| #if !defined CUDA_DISABLER | ||||
| #include "opencv2/opencv_modules.hpp" | ||||
|  | ||||
| #include "opencv2/core/cuda/common.hpp" | ||||
| #include "opencv2/core/cuda/functional.hpp" | ||||
| #include "opencv2/core/cuda/transform.hpp" | ||||
| #include "opencv2/core/cuda/saturate_cast.hpp" | ||||
| #include "opencv2/core/cuda/simd_functions.hpp" | ||||
| #ifndef HAVE_OPENCV_CUDEV | ||||
|  | ||||
| #include "arithm_func_traits.hpp" | ||||
| #error "opencv_cudev is required" | ||||
|  | ||||
| using namespace cv::cuda; | ||||
| using namespace cv::cuda::device; | ||||
| #else | ||||
|  | ||||
| namespace arithm | ||||
| #include "opencv2/cudev.hpp" | ||||
|  | ||||
| using namespace cv::cudev; | ||||
|  | ||||
| void mulMat(const GpuMat& src1, const GpuMat& src2, GpuMat& dst, const GpuMat&, double scale, Stream& stream, int); | ||||
| void mulMat_8uc4_32f(const GpuMat& src1, const GpuMat& src2, GpuMat& dst, Stream& stream); | ||||
| void mulMat_16sc4_32f(const GpuMat& src1, const GpuMat& src2, GpuMat& dst, Stream& stream); | ||||
|  | ||||
| namespace | ||||
| { | ||||
|     struct Mul_8uc4_32f : binary_function<uint, float, uint> | ||||
|     { | ||||
|         __device__ __forceinline__ uint operator ()(uint a, float b) const | ||||
|         { | ||||
|             uint res = 0; | ||||
|  | ||||
|             res |= (saturate_cast<uchar>((0xffu & (a      )) * b)      ); | ||||
|             res |= (saturate_cast<uchar>((0xffu & (a >>  8)) * b) <<  8); | ||||
|             res |= (saturate_cast<uchar>((0xffu & (a >> 16)) * b) << 16); | ||||
|             res |= (saturate_cast<uchar>((0xffu & (a >> 24)) * b) << 24); | ||||
|  | ||||
|             return res; | ||||
|         } | ||||
|  | ||||
|         __host__ __device__ __forceinline__ Mul_8uc4_32f() {} | ||||
|         __host__ __device__ __forceinline__ Mul_8uc4_32f(const Mul_8uc4_32f&) {} | ||||
|     }; | ||||
|  | ||||
|     struct Mul_16sc4_32f : binary_function<short4, float, short4> | ||||
|     { | ||||
|         __device__ __forceinline__ short4 operator ()(short4 a, float b) const | ||||
|         { | ||||
|             return make_short4(saturate_cast<short>(a.x * b), saturate_cast<short>(a.y * b), | ||||
|                                saturate_cast<short>(a.z * b), saturate_cast<short>(a.w * b)); | ||||
|         } | ||||
|  | ||||
|         __host__ __device__ __forceinline__ Mul_16sc4_32f() {} | ||||
|         __host__ __device__ __forceinline__ Mul_16sc4_32f(const Mul_16sc4_32f&) {} | ||||
|     }; | ||||
|  | ||||
|     template <typename T, typename D> struct Mul : binary_function<T, T, D> | ||||
|     template <typename T, typename D> struct MulOp : binary_function<T, T, D> | ||||
|     { | ||||
|         __device__ __forceinline__ D operator ()(T a, T b) const | ||||
|         { | ||||
|             return saturate_cast<D>(a * b); | ||||
|         } | ||||
|  | ||||
|         __host__ __device__ __forceinline__ Mul() {} | ||||
|         __host__ __device__ __forceinline__ Mul(const Mul&) {} | ||||
|     }; | ||||
|  | ||||
|     template <typename T, typename S, typename D> struct MulScale : binary_function<T, T, D> | ||||
|     template <typename T, typename S, typename D> struct MulScaleOp : binary_function<T, T, D> | ||||
|     { | ||||
|         S scale; | ||||
|  | ||||
|         __host__ explicit MulScale(S scale_) : scale(scale_) {} | ||||
|  | ||||
|         __device__ __forceinline__ D operator ()(T a, T b) const | ||||
|         { | ||||
|             return saturate_cast<D>(scale * a * b); | ||||
|         } | ||||
|     }; | ||||
| } | ||||
|  | ||||
| namespace cv { namespace cuda { namespace device | ||||
| { | ||||
|     template <> struct TransformFunctorTraits<arithm::Mul_8uc4_32f> : arithm::ArithmFuncTraits<sizeof(uint), sizeof(uint)> | ||||
|     template <typename ScalarDepth> struct TransformPolicy : DefaultTransformPolicy | ||||
|     { | ||||
|     }; | ||||
|  | ||||
|     template <typename T, typename D> struct TransformFunctorTraits< arithm::Mul<T, D> > : arithm::ArithmFuncTraits<sizeof(T), sizeof(D)> | ||||
|     template <> struct TransformPolicy<double> : DefaultTransformPolicy | ||||
|     { | ||||
|         enum { | ||||
|             shift = 1 | ||||
|         }; | ||||
|     }; | ||||
|  | ||||
|     template <typename T, typename S, typename D> struct TransformFunctorTraits< arithm::MulScale<T, S, D> > : arithm::ArithmFuncTraits<sizeof(T), sizeof(D)> | ||||
|     { | ||||
|     }; | ||||
| }}} | ||||
|  | ||||
| namespace arithm | ||||
| { | ||||
|     void mulMat_8uc4_32f(PtrStepSz<uint> src1, PtrStepSzf src2, PtrStepSz<uint> dst, cudaStream_t stream) | ||||
|     { | ||||
|         device::transform(src1, src2, dst, Mul_8uc4_32f(), WithOutMask(), stream); | ||||
|     } | ||||
|  | ||||
|     void mulMat_16sc4_32f(PtrStepSz<short4> src1, PtrStepSzf src2, PtrStepSz<short4> dst, cudaStream_t stream) | ||||
|     { | ||||
|         device::transform(src1, src2, dst, Mul_16sc4_32f(), WithOutMask(), stream); | ||||
|     } | ||||
|  | ||||
|     template <typename T, typename S, typename D> | ||||
|     void mulMat(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream) | ||||
|     void mulMatImpl(const GpuMat& src1, const GpuMat& src2, const GpuMat& dst, double scale, Stream& stream) | ||||
|     { | ||||
|         if (scale == 1) | ||||
|         { | ||||
|             Mul<T, D> op; | ||||
|             device::transform((PtrStepSz<T>) src1, (PtrStepSz<T>) src2, (PtrStepSz<D>) dst, op, WithOutMask(), stream); | ||||
|             MulOp<T, D> op; | ||||
|             gridTransformBinary_< TransformPolicy<S> >(globPtr<T>(src1), globPtr<T>(src2), globPtr<D>(dst), op, stream); | ||||
|         } | ||||
|         else | ||||
|         { | ||||
|             MulScale<T, S, D> op(static_cast<S>(scale)); | ||||
|             device::transform((PtrStepSz<T>) src1, (PtrStepSz<T>) src2, (PtrStepSz<D>) dst, op, WithOutMask(), stream); | ||||
|             MulScaleOp<T, S, D> op; | ||||
|             op.scale = static_cast<S>(scale); | ||||
|             gridTransformBinary_< TransformPolicy<S> >(globPtr<T>(src1), globPtr<T>(src2), globPtr<D>(dst), op, stream); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     template void mulMat<uchar, float, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<uchar, float, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<uchar, float, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<uchar, float, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<uchar, float, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<uchar, float, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<uchar, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|  | ||||
|     template void mulMat<schar, float, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<schar, float, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<schar, float, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<schar, float, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<schar, float, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<schar, float, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<schar, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulMat<ushort, float, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<ushort, float, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<ushort, float, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<ushort, float, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<ushort, float, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<ushort, float, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<ushort, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulMat<short, float, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<short, float, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<short, float, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<short, float, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<short, float, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<short, float, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<short, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulMat<int, float, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<int, float, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<int, float, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<int, float, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<int, float, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<int, float, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<int, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulMat<float, float, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<float, float, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<float, float, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<float, float, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<float, float, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<float, float, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<float, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulMat<double, double, uchar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<double, double, schar>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<double, double, ushort>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<double, double, short>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<double, double, int>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     //template void mulMat<double, double, float>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
|     template void mulMat<double, double, double>(PtrStepSzb src1, PtrStepSzb src2, PtrStepSzb dst, double scale, cudaStream_t stream); | ||||
| } | ||||
|  | ||||
| #endif // CUDA_DISABLER | ||||
| void mulMat(const GpuMat& src1, const GpuMat& src2, GpuMat& dst, const GpuMat&, double scale, Stream& stream, int) | ||||
| { | ||||
|     typedef void (*func_t)(const GpuMat& src1, const GpuMat& src2, const GpuMat& dst, double scale, Stream& stream); | ||||
|     static const func_t funcs[7][7] = | ||||
|     { | ||||
|         { | ||||
|             mulMatImpl<uchar, float, uchar>, | ||||
|             mulMatImpl<uchar, float, schar>, | ||||
|             mulMatImpl<uchar, float, ushort>, | ||||
|             mulMatImpl<uchar, float, short>, | ||||
|             mulMatImpl<uchar, float, int>, | ||||
|             mulMatImpl<uchar, float, float>, | ||||
|             mulMatImpl<uchar, double, double> | ||||
|         }, | ||||
|         { | ||||
|             mulMatImpl<schar, float, uchar>, | ||||
|             mulMatImpl<schar, float, schar>, | ||||
|             mulMatImpl<schar, float, ushort>, | ||||
|             mulMatImpl<schar, float, short>, | ||||
|             mulMatImpl<schar, float, int>, | ||||
|             mulMatImpl<schar, float, float>, | ||||
|             mulMatImpl<schar, double, double> | ||||
|         }, | ||||
|         { | ||||
|             0 /*mulMatImpl<ushort, float, uchar>*/, | ||||
|             0 /*mulMatImpl<ushort, float, schar>*/, | ||||
|             mulMatImpl<ushort, float, ushort>, | ||||
|             mulMatImpl<ushort, float, short>, | ||||
|             mulMatImpl<ushort, float, int>, | ||||
|             mulMatImpl<ushort, float, float>, | ||||
|             mulMatImpl<ushort, double, double> | ||||
|         }, | ||||
|         { | ||||
|             0 /*mulMatImpl<short, float, uchar>*/, | ||||
|             0 /*mulMatImpl<short, float, schar>*/, | ||||
|             mulMatImpl<short, float, ushort>, | ||||
|             mulMatImpl<short, float, short>, | ||||
|             mulMatImpl<short, float, int>, | ||||
|             mulMatImpl<short, float, float>, | ||||
|             mulMatImpl<short, double, double> | ||||
|         }, | ||||
|         { | ||||
|             0 /*mulMatImpl<int, float, uchar>*/, | ||||
|             0 /*mulMatImpl<int, float, schar>*/, | ||||
|             0 /*mulMatImpl<int, float, ushort>*/, | ||||
|             0 /*mulMatImpl<int, float, short>*/, | ||||
|             mulMatImpl<int, float, int>, | ||||
|             mulMatImpl<int, float, float>, | ||||
|             mulMatImpl<int, double, double> | ||||
|         }, | ||||
|         { | ||||
|             0 /*mulMatImpl<float, float, uchar>*/, | ||||
|             0 /*mulMatImpl<float, float, schar>*/, | ||||
|             0 /*mulMatImpl<float, float, ushort>*/, | ||||
|             0 /*mulMatImpl<float, float, short>*/, | ||||
|             0 /*mulMatImpl<float, float, int>*/, | ||||
|             mulMatImpl<float, float, float>, | ||||
|             mulMatImpl<float, double, double> | ||||
|         }, | ||||
|         { | ||||
|             0 /*mulMatImpl<double, double, uchar>*/, | ||||
|             0 /*mulMatImpl<double, double, schar>*/, | ||||
|             0 /*mulMatImpl<double, double, ushort>*/, | ||||
|             0 /*mulMatImpl<double, double, short>*/, | ||||
|             0 /*mulMatImpl<double, double, int>*/, | ||||
|             0 /*mulMatImpl<double, double, float>*/, | ||||
|             mulMatImpl<double, double, double> | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     const int sdepth = src1.depth(); | ||||
|     const int ddepth = dst.depth(); | ||||
|  | ||||
|     CV_DbgAssert( sdepth < 7 && ddepth < 7 ); | ||||
|  | ||||
|     GpuMat src1_ = src1.reshape(1); | ||||
|     GpuMat src2_ = src2.reshape(1); | ||||
|     GpuMat dst_ = dst.reshape(1); | ||||
|  | ||||
|     const func_t func = funcs[sdepth][ddepth]; | ||||
|  | ||||
|     if (!func) | ||||
|         CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported combination of source and destination types"); | ||||
|  | ||||
|     func(src1_, src2_, dst_, scale, stream); | ||||
| } | ||||
|  | ||||
| namespace | ||||
| { | ||||
|     template <typename T> | ||||
|     struct MulOpSpecial : binary_function<T, float, T> | ||||
|     { | ||||
|         __device__ __forceinline__ T operator ()(const T& a, float b) const | ||||
|         { | ||||
|             typedef typename VecTraits<T>::elem_type elem_type; | ||||
|  | ||||
|             T res; | ||||
|  | ||||
|             res.x = saturate_cast<elem_type>(a.x * b); | ||||
|             res.y = saturate_cast<elem_type>(a.y * b); | ||||
|             res.z = saturate_cast<elem_type>(a.z * b); | ||||
|             res.w = saturate_cast<elem_type>(a.w * b); | ||||
|  | ||||
|             return res; | ||||
|         } | ||||
|     }; | ||||
| } | ||||
|  | ||||
| void mulMat_8uc4_32f(const GpuMat& src1, const GpuMat& src2, GpuMat& dst, Stream& stream) | ||||
| { | ||||
|     gridTransformBinary(globPtr<uchar4>(src1), globPtr<float>(src2), globPtr<uchar4>(dst), MulOpSpecial<uchar4>(), stream); | ||||
| } | ||||
|  | ||||
| void mulMat_16sc4_32f(const GpuMat& src1, const GpuMat& src2, GpuMat& dst, Stream& stream) | ||||
| { | ||||
|     gridTransformBinary(globPtr<short4>(src1), globPtr<float>(src2), globPtr<short4>(dst), MulOpSpecial<short4>(), stream); | ||||
| } | ||||
|  | ||||
| #endif | ||||
|   | ||||
| @@ -40,105 +40,143 @@ | ||||
| // | ||||
| //M*/ | ||||
|  | ||||
| #if !defined CUDA_DISABLER | ||||
| #include "opencv2/opencv_modules.hpp" | ||||
|  | ||||
| #include "opencv2/core/cuda/common.hpp" | ||||
| #include "opencv2/core/cuda/functional.hpp" | ||||
| #include "opencv2/core/cuda/transform.hpp" | ||||
| #include "opencv2/core/cuda/saturate_cast.hpp" | ||||
| #include "opencv2/core/cuda/simd_functions.hpp" | ||||
| #ifndef HAVE_OPENCV_CUDEV | ||||
|  | ||||
| #include "arithm_func_traits.hpp" | ||||
| #error "opencv_cudev is required" | ||||
|  | ||||
| using namespace cv::cuda; | ||||
| using namespace cv::cuda::device; | ||||
| #else | ||||
|  | ||||
| namespace arithm | ||||
| #include "opencv2/cudev.hpp" | ||||
|  | ||||
| using namespace cv::cudev; | ||||
|  | ||||
| void mulScalar(const GpuMat& src, cv::Scalar val, bool, GpuMat& dst, const GpuMat& mask, double scale, Stream& stream, int); | ||||
|  | ||||
| namespace | ||||
| { | ||||
|     template <typename T, typename S, typename D> struct MulScalar : unary_function<T, D> | ||||
|     template <typename SrcType, typename ScalarType, typename DstType> struct MulScalarOp : unary_function<SrcType, DstType> | ||||
|     { | ||||
|         S val; | ||||
|         ScalarType val; | ||||
|  | ||||
|         __host__ explicit MulScalar(S val_) : val(val_) {} | ||||
|  | ||||
|         __device__ __forceinline__ D operator ()(T a) const | ||||
|         __device__ __forceinline__ DstType operator ()(SrcType a) const | ||||
|         { | ||||
|             return saturate_cast<D>(a * val); | ||||
|             return saturate_cast<DstType>(saturate_cast<ScalarType>(a) * val); | ||||
|         } | ||||
|     }; | ||||
| } | ||||
|  | ||||
| namespace cv { namespace cuda { namespace device | ||||
| { | ||||
|     template <typename T, typename S, typename D> struct TransformFunctorTraits< arithm::MulScalar<T, S, D> > : arithm::ArithmFuncTraits<sizeof(T), sizeof(D)> | ||||
|     template <typename ScalarDepth> struct TransformPolicy : DefaultTransformPolicy | ||||
|     { | ||||
|     }; | ||||
| }}} | ||||
|  | ||||
| namespace arithm | ||||
| { | ||||
|     template <typename T, typename S, typename D> | ||||
|     void mulScalar(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream) | ||||
|     template <> struct TransformPolicy<double> : DefaultTransformPolicy | ||||
|     { | ||||
|         MulScalar<T, S, D> op(static_cast<S>(val)); | ||||
|         device::transform((PtrStepSz<T>) src1, (PtrStepSz<D>) dst, op, WithOutMask(), stream); | ||||
|         enum { | ||||
|             shift = 1 | ||||
|         }; | ||||
|     }; | ||||
|  | ||||
|     template <typename SrcType, typename ScalarDepth, typename DstType> | ||||
|     void mulScalarImpl(const GpuMat& src, cv::Scalar value, GpuMat& dst, Stream& stream) | ||||
|     { | ||||
|         typedef typename MakeVec<ScalarDepth, VecTraits<SrcType>::cn>::type ScalarType; | ||||
|  | ||||
|         cv::Scalar_<ScalarDepth> value_ = value; | ||||
|  | ||||
|         MulScalarOp<SrcType, ScalarType, DstType> op; | ||||
|         op.val = VecTraits<ScalarType>::make(value_.val); | ||||
|  | ||||
|         gridTransformUnary_< TransformPolicy<ScalarDepth> >(globPtr<SrcType>(src), globPtr<DstType>(dst), op, stream); | ||||
|     } | ||||
|  | ||||
|     template void mulScalar<uchar, float, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<uchar, float, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<uchar, float, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<uchar, float, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<uchar, float, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<uchar, float, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<uchar, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|  | ||||
|     template void mulScalar<schar, float, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<schar, float, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<schar, float, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<schar, float, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<schar, float, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<schar, float, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<schar, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulScalar<ushort, float, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<ushort, float, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<ushort, float, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<ushort, float, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<ushort, float, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<ushort, float, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<ushort, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulScalar<short, float, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<short, float, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<short, float, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<short, float, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<short, float, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<short, float, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<short, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulScalar<int, float, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<int, float, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<int, float, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<int, float, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<int, float, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<int, float, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<int, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulScalar<float, float, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<float, float, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<float, float, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<float, float, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<float, float, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<float, float, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<float, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|  | ||||
|     //template void mulScalar<double, double, uchar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<double, double, schar>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<double, double, ushort>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<double, double, short>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<double, double, int>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     //template void mulScalar<double, double, float>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
|     template void mulScalar<double, double, double>(PtrStepSzb src1, double val, PtrStepSzb dst, cudaStream_t stream); | ||||
| } | ||||
|  | ||||
| #endif // CUDA_DISABLER | ||||
| void mulScalar(const GpuMat& src, cv::Scalar val, bool, GpuMat& dst, const GpuMat&, double scale, Stream& stream, int) | ||||
| { | ||||
|     typedef void (*func_t)(const GpuMat& src, cv::Scalar val, GpuMat& dst, Stream& stream); | ||||
|     static const func_t funcs[7][7][4] = | ||||
|     { | ||||
|         { | ||||
|             {mulScalarImpl<uchar, float, uchar>, mulScalarImpl<uchar2, float, uchar2>, mulScalarImpl<uchar3, float, uchar3>, mulScalarImpl<uchar4, float, uchar4>}, | ||||
|             {mulScalarImpl<uchar, float, schar>, mulScalarImpl<uchar2, float, char2>, mulScalarImpl<uchar3, float, char3>, mulScalarImpl<uchar4, float, char4>}, | ||||
|             {mulScalarImpl<uchar, float, ushort>, mulScalarImpl<uchar2, float, ushort2>, mulScalarImpl<uchar3, float, ushort3>, mulScalarImpl<uchar4, float, ushort4>}, | ||||
|             {mulScalarImpl<uchar, float, short>, mulScalarImpl<uchar2, float, short2>, mulScalarImpl<uchar3, float, short3>, mulScalarImpl<uchar4, float, short4>}, | ||||
|             {mulScalarImpl<uchar, float, int>, mulScalarImpl<uchar2, float, int2>, mulScalarImpl<uchar3, float, int3>, mulScalarImpl<uchar4, float, int4>}, | ||||
|             {mulScalarImpl<uchar, float, float>, mulScalarImpl<uchar2, float, float2>, mulScalarImpl<uchar3, float, float3>, mulScalarImpl<uchar4, float, float4>}, | ||||
|             {mulScalarImpl<uchar, double, double>, mulScalarImpl<uchar2, double, double2>, mulScalarImpl<uchar3, double, double3>, mulScalarImpl<uchar4, double, double4>} | ||||
|         }, | ||||
|         { | ||||
|             {mulScalarImpl<schar, float, uchar>, mulScalarImpl<char2, float, uchar2>, mulScalarImpl<char3, float, uchar3>, mulScalarImpl<char4, float, uchar4>}, | ||||
|             {mulScalarImpl<schar, float, schar>, mulScalarImpl<char2, float, char2>, mulScalarImpl<char3, float, char3>, mulScalarImpl<char4, float, char4>}, | ||||
|             {mulScalarImpl<schar, float, ushort>, mulScalarImpl<char2, float, ushort2>, mulScalarImpl<char3, float, ushort3>, mulScalarImpl<char4, float, ushort4>}, | ||||
|             {mulScalarImpl<schar, float, short>, mulScalarImpl<char2, float, short2>, mulScalarImpl<char3, float, short3>, mulScalarImpl<char4, float, short4>}, | ||||
|             {mulScalarImpl<schar, float, int>, mulScalarImpl<char2, float, int2>, mulScalarImpl<char3, float, int3>, mulScalarImpl<char4, float, int4>}, | ||||
|             {mulScalarImpl<schar, float, float>, mulScalarImpl<char2, float, float2>, mulScalarImpl<char3, float, float3>, mulScalarImpl<char4, float, float4>}, | ||||
|             {mulScalarImpl<schar, double, double>, mulScalarImpl<char2, double, double2>, mulScalarImpl<char3, double, double3>, mulScalarImpl<char4, double, double4>} | ||||
|         }, | ||||
|         { | ||||
|             {0 /*mulScalarImpl<ushort, float, uchar>*/, 0 /*mulScalarImpl<ushort2, float, uchar2>*/, 0 /*mulScalarImpl<ushort3, float, uchar3>*/, 0 /*mulScalarImpl<ushort4, float, uchar4>*/}, | ||||
|             {0 /*mulScalarImpl<ushort, float, schar>*/, 0 /*mulScalarImpl<ushort2, float, char2>*/, 0 /*mulScalarImpl<ushort3, float, char3>*/, 0 /*mulScalarImpl<ushort4, float, char4>*/}, | ||||
|             {mulScalarImpl<ushort, float, ushort>, mulScalarImpl<ushort2, float, ushort2>, mulScalarImpl<ushort3, float, ushort3>, mulScalarImpl<ushort4, float, ushort4>}, | ||||
|             {mulScalarImpl<ushort, float, short>, mulScalarImpl<ushort2, float, short2>, mulScalarImpl<ushort3, float, short3>, mulScalarImpl<ushort4, float, short4>}, | ||||
|             {mulScalarImpl<ushort, float, int>, mulScalarImpl<ushort2, float, int2>, mulScalarImpl<ushort3, float, int3>, mulScalarImpl<ushort4, float, int4>}, | ||||
|             {mulScalarImpl<ushort, float, float>, mulScalarImpl<ushort2, float, float2>, mulScalarImpl<ushort3, float, float3>, mulScalarImpl<ushort4, float, float4>}, | ||||
|             {mulScalarImpl<ushort, double, double>, mulScalarImpl<ushort2, double, double2>, mulScalarImpl<ushort3, double, double3>, mulScalarImpl<ushort4, double, double4>} | ||||
|         }, | ||||
|         { | ||||
|             {0 /*mulScalarImpl<short, float, uchar>*/, 0 /*mulScalarImpl<short2, float, uchar2>*/, 0 /*mulScalarImpl<short3, float, uchar3>*/, 0 /*mulScalarImpl<short4, float, uchar4>*/}, | ||||
|             {0 /*mulScalarImpl<short, float, schar>*/, 0 /*mulScalarImpl<short2, float, char2>*/, 0 /*mulScalarImpl<short3, float, char3>*/, 0 /*mulScalarImpl<short4, float, char4>*/}, | ||||
|             {mulScalarImpl<short, float, ushort>, mulScalarImpl<short2, float, ushort2>, mulScalarImpl<short3, float, ushort3>, mulScalarImpl<short4, float, ushort4>}, | ||||
|             {mulScalarImpl<short, float, short>, mulScalarImpl<short2, float, short2>, mulScalarImpl<short3, float, short3>, mulScalarImpl<short4, float, short4>}, | ||||
|             {mulScalarImpl<short, float, int>, mulScalarImpl<short2, float, int2>, mulScalarImpl<short3, float, int3>, mulScalarImpl<short4, float, int4>}, | ||||
|             {mulScalarImpl<short, float, float>, mulScalarImpl<short2, float, float2>, mulScalarImpl<short3, float, float3>, mulScalarImpl<short4, float, float4>}, | ||||
|             {mulScalarImpl<short, double, double>, mulScalarImpl<short2, double, double2>, mulScalarImpl<short3, double, double3>, mulScalarImpl<short4, double, double4>} | ||||
|         }, | ||||
|         { | ||||
|             {0 /*mulScalarImpl<int, float, uchar>*/, 0 /*mulScalarImpl<int2, float, uchar2>*/, 0 /*mulScalarImpl<int3, float, uchar3>*/, 0 /*mulScalarImpl<int4, float, uchar4>*/}, | ||||
|             {0 /*mulScalarImpl<int, float, schar>*/, 0 /*mulScalarImpl<int2, float, char2>*/, 0 /*mulScalarImpl<int3, float, char3>*/, 0 /*mulScalarImpl<int4, float, char4>*/}, | ||||
|             {0 /*mulScalarImpl<int, float, ushort>*/, 0 /*mulScalarImpl<int2, float, ushort2>*/, 0 /*mulScalarImpl<int3, float, ushort3>*/, 0 /*mulScalarImpl<int4, float, ushort4>*/}, | ||||
|             {0 /*mulScalarImpl<int, float, short>*/, 0 /*mulScalarImpl<int2, float, short2>*/, 0 /*mulScalarImpl<int3, float, short3>*/, 0 /*mulScalarImpl<int4, float, short4>*/}, | ||||
|             {mulScalarImpl<int, float, int>, mulScalarImpl<int2, float, int2>, mulScalarImpl<int3, float, int3>, mulScalarImpl<int4, float, int4>}, | ||||
|             {mulScalarImpl<int, float, float>, mulScalarImpl<int2, float, float2>, mulScalarImpl<int3, float, float3>, mulScalarImpl<int4, float, float4>}, | ||||
|             {mulScalarImpl<int, double, double>, mulScalarImpl<int2, double, double2>, mulScalarImpl<int3, double, double3>, mulScalarImpl<int4, double, double4>} | ||||
|         }, | ||||
|         { | ||||
|             {0 /*mulScalarImpl<float, float, uchar>*/, 0 /*mulScalarImpl<float2, float, uchar2>*/, 0 /*mulScalarImpl<float3, float, uchar3>*/, 0 /*mulScalarImpl<float4, float, uchar4>*/}, | ||||
|             {0 /*mulScalarImpl<float, float, schar>*/, 0 /*mulScalarImpl<float2, float, char2>*/, 0 /*mulScalarImpl<float3, float, char3>*/, 0 /*mulScalarImpl<float4, float, char4>*/}, | ||||
|             {0 /*mulScalarImpl<float, float, ushort>*/, 0 /*mulScalarImpl<float2, float, ushort2>*/, 0 /*mulScalarImpl<float3, float, ushort3>*/, 0 /*mulScalarImpl<float4, float, ushort4>*/}, | ||||
|             {0 /*mulScalarImpl<float, float, short>*/, 0 /*mulScalarImpl<float2, float, short2>*/, 0 /*mulScalarImpl<float3, float, short3>*/, 0 /*mulScalarImpl<float4, float, short4>*/}, | ||||
|             {0 /*mulScalarImpl<float, float, int>*/, 0 /*mulScalarImpl<float2, float, int2>*/, 0 /*mulScalarImpl<float3, float, int3>*/, 0 /*mulScalarImpl<float4, float, int4>*/}, | ||||
|             {mulScalarImpl<float, float, float>, mulScalarImpl<float2, float, float2>, mulScalarImpl<float3, float, float3>, mulScalarImpl<float4, float, float4>}, | ||||
|             {mulScalarImpl<float, double, double>, mulScalarImpl<float2, double, double2>, mulScalarImpl<float3, double, double3>, mulScalarImpl<float4, double, double4>} | ||||
|         }, | ||||
|         { | ||||
|             {0 /*mulScalarImpl<double, double, uchar>*/, 0 /*mulScalarImpl<double2, double, uchar2>*/, 0 /*mulScalarImpl<double3, double, uchar3>*/, 0 /*mulScalarImpl<double4, double, uchar4>*/}, | ||||
|             {0 /*mulScalarImpl<double, double, schar>*/, 0 /*mulScalarImpl<double2, double, char2>*/, 0 /*mulScalarImpl<double3, double, char3>*/, 0 /*mulScalarImpl<double4, double, char4>*/}, | ||||
|             {0 /*mulScalarImpl<double, double, ushort>*/, 0 /*mulScalarImpl<double2, double, ushort2>*/, 0 /*mulScalarImpl<double3, double, ushort3>*/, 0 /*mulScalarImpl<double4, double, ushort4>*/}, | ||||
|             {0 /*mulScalarImpl<double, double, short>*/, 0 /*mulScalarImpl<double2, double, short2>*/, 0 /*mulScalarImpl<double3, double, short3>*/, 0 /*mulScalarImpl<double4, double, short4>*/}, | ||||
|             {0 /*mulScalarImpl<double, double, int>*/, 0 /*mulScalarImpl<double2, double, int2>*/, 0 /*mulScalarImpl<double3, double, int3>*/, 0 /*mulScalarImpl<double4, double, int4>*/}, | ||||
|             {0 /*mulScalarImpl<double, double, float>*/, 0 /*mulScalarImpl<double2, double, float2>*/, 0 /*mulScalarImpl<double3, double, float3>*/, 0 /*mulScalarImpl<double4, double, float4>*/}, | ||||
|             {mulScalarImpl<double, double, double>, mulScalarImpl<double2, double, double2>, mulScalarImpl<double3, double, double3>, mulScalarImpl<double4, double, double4>} | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     const int sdepth = src.depth(); | ||||
|     const int ddepth = dst.depth(); | ||||
|     const int cn = src.channels(); | ||||
|  | ||||
|     CV_DbgAssert( sdepth < 7 && ddepth < 7 && cn <= 4 ); | ||||
|  | ||||
|     val[0] *= scale; | ||||
|     val[1] *= scale; | ||||
|     val[2] *= scale; | ||||
|     val[3] *= scale; | ||||
|  | ||||
|     const func_t func = funcs[sdepth][ddepth][cn - 1]; | ||||
|  | ||||
|     if (!func) | ||||
|         CV_Error(cv::Error::StsUnsupportedFormat, "Unsupported combination of source and destination types"); | ||||
|  | ||||
|     func(src, val, dst, stream); | ||||
| } | ||||
|  | ||||
| #endif | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Vladislav Vinogradov
					Vladislav Vinogradov