Improved thrust interop tutorial.

This commit is contained in:
Dan
2015-09-16 12:03:35 -04:00
parent 09d392f09d
commit 23fc5930b7
4 changed files with 104 additions and 47 deletions

View File

@@ -6,20 +6,10 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/device_ptr.h>
template<typename T> struct
CV_TYPE
{
static const int DEPTH;
};
template<> static const int CV_TYPE<float>::DEPTH = CV_32F;
template<> static const int CV_TYPE<double>::DEPTH = CV_64F;
template<> static const int CV_TYPE<int>::DEPTH = CV_32S;
template<> static const int CV_TYPE<uchar>::DEPTH = CV_8U;
template<> static const int CV_TYPE<char>::DEPTH = CV_8S;
template<> static const int CV_TYPE<ushort>::DEPTH = CV_16U;
template<> static const int CV_TYPE<short>::DEPTH = CV_16S;
/*
@Brief step_functor is an object to correctly step a thrust iterator according to the stride of a matrix
*/
//! [step_functor]
template<typename T> struct step_functor : public thrust::unary_function<int, int>
{
int columns;
@@ -41,7 +31,8 @@ template<typename T> struct step_functor : public thrust::unary_function<int, in
return idx;
}
};
//! [step_functor]
//! [begin_itr]
/*
@Brief GpuMatBeginItr returns a thrust compatible iterator to the beginning of a GPU mat's memory.
@Param mat is the input matrix
@@ -52,11 +43,13 @@ thrust::permutation_iterator<thrust::device_ptr<T>, thrust::transform_iterator<s
{
if (channel == -1)
mat = mat.reshape(1);
CV_Assert(mat.depth() == CV_TYPE<T>::DEPTH);
CV_Assert(mat.depth() == cv::DataType<T>::depth);
CV_Assert(channel < mat.channels());
return thrust::make_permutation_iterator(thrust::device_pointer_cast(mat.ptr<T>(0) + channel),
thrust::make_transform_iterator(thrust::make_counting_iterator(0), step_functor<T>(mat.cols, mat.step / sizeof(T), mat.channels())));
}
//! [begin_itr]
//! [end_itr]
/*
@Brief GpuMatEndItr returns a thrust compatible iterator to the end of a GPU mat's memory.
@Param mat is the input matrix
@@ -67,8 +60,11 @@ thrust::permutation_iterator<thrust::device_ptr<T>, thrust::transform_iterator<s
{
if (channel == -1)
mat = mat.reshape(1);
CV_Assert(mat.depth() == CV_TYPE<T>::DEPTH);
CV_Assert(mat.depth() == cv::DataType<T>::depth);
CV_Assert(channel < mat.channels());
return thrust::make_permutation_iterator(thrust::device_pointer_cast(mat.ptr<T>(0) + channel),
thrust::make_transform_iterator(thrust::make_counting_iterator(mat.rows*mat.cols), step_functor<T>(mat.cols, mat.step / sizeof(T), mat.channels())));
}
}
//! [end_itr]