Added support of GPU in stitching seam estimators

This commit is contained in:
Alexey Spizhevoy 2011-09-26 07:57:05 +00:00
parent 16f5c67914
commit 66b41b67f9
4 changed files with 401 additions and 35 deletions

View File

@ -51,9 +51,6 @@ namespace detail {
class CV_EXPORTS SeamFinder
{
public:
enum { NO, VORONOI, GC_COLOR, GC_COLOR_GRAD };
static Ptr<SeamFinder> createDefault(int type);
virtual ~SeamFinder() {}
virtual void find(const std::vector<Mat> &src, const std::vector<Point> &corners,
std::vector<Mat> &masks) = 0;
@ -89,10 +86,16 @@ private:
};
class CV_EXPORTS GraphCutSeamFinder : public SeamFinder
class CV_EXPORTS GraphCutSeamFinderBase
{
public:
enum { COST_COLOR, COST_COLOR_GRAD };
};
class CV_EXPORTS GraphCutSeamFinder : public GraphCutSeamFinderBase, public SeamFinder
{
public:
GraphCutSeamFinder(int cost_type = COST_COLOR_GRAD, float terminal_cost = 10000.f,
float bad_region_penalty = 1000.f);
@ -105,6 +108,33 @@ private:
Ptr<Impl> impl_;
};
#ifndef ANDROID
class CV_EXPORTS GraphCutSeamFinderGpu : public GraphCutSeamFinderBase, public PairwiseSeamFinder
{
public:
GraphCutSeamFinderGpu(int cost_type = COST_COLOR_GRAD, float terminal_cost = 10000.f,
float bad_region_penalty = 1000.f)
: cost_type_(cost_type), terminal_cost_(terminal_cost),
bad_region_penalty_(bad_region_penalty) {}
void find(const std::vector<cv::Mat> &src, const std::vector<cv::Point> &corners,
std::vector<cv::Mat> &masks);
void findInPair(size_t first, size_t second, Rect roi);
private:
void setGraphWeightsColor(const cv::Mat &img1, const cv::Mat &img2, const cv::Mat &mask1, const cv::Mat &mask2,
cv::Mat &terminals, cv::Mat &leftT, cv::Mat &rightT, cv::Mat &top, cv::Mat &bottom);
void setGraphWeightsColorGrad(const cv::Mat &img1, const cv::Mat &img2, const cv::Mat &dx1, const cv::Mat &dx2,
const cv::Mat &dy1, const cv::Mat &dy2, const cv::Mat &mask1, const cv::Mat &mask2,
cv::Mat &terminals, cv::Mat &leftT, cv::Mat &rightT, cv::Mat &top, cv::Mat &bottom);
std::vector<Mat> dx_, dy_;
int cost_type_;
float terminal_cost_;
float bad_region_penalty_;
};
#endif
} // namespace detail
} // namespace cv

View File

@ -42,26 +42,9 @@
#include "precomp.hpp"
using namespace std;
namespace cv {
namespace detail {
Ptr<SeamFinder> SeamFinder::createDefault(int type)
{
if (type == NO)
return new NoSeamFinder();
if (type == VORONOI)
return new VoronoiSeamFinder();
if (type == GC_COLOR)
return new GraphCutSeamFinder(GraphCutSeamFinder::COST_COLOR);
if (type == GC_COLOR_GRAD)
return new GraphCutSeamFinder(GraphCutSeamFinder::COST_COLOR_GRAD);
CV_Error(CV_StsBadArg, "unsupported seam finding method");
return NULL;
}
void PairwiseSeamFinder::find(const vector<Mat> &src, const vector<Point> &corners,
vector<Mat> &masks)
{
@ -407,6 +390,333 @@ void GraphCutSeamFinder::find(const vector<Mat> &src, const vector<Point> &corne
impl_->find(src, corners, masks);
}
#ifndef ANDROID
void GraphCutSeamFinderGpu::find(const vector<Mat> &src, const vector<Point> &corners,
vector<Mat> &masks)
{
// Compute gradients
dx_.resize(src.size());
dy_.resize(src.size());
Mat dx, dy;
for (size_t i = 0; i < src.size(); ++i)
{
CV_Assert(src[i].channels() == 3);
Sobel(src[i], dx, CV_32F, 1, 0);
Sobel(src[i], dy, CV_32F, 0, 1);
dx_[i].create(src[i].size(), CV_32F);
dy_[i].create(src[i].size(), CV_32F);
for (int y = 0; y < src[i].rows; ++y)
{
const Point3f* dx_row = dx.ptr<Point3f>(y);
const Point3f* dy_row = dy.ptr<Point3f>(y);
float* dx_row_ = dx_[i].ptr<float>(y);
float* dy_row_ = dy_[i].ptr<float>(y);
for (int x = 0; x < src[i].cols; ++x)
{
dx_row_[x] = normL2(dx_row[x]);
dy_row_[x] = normL2(dy_row[x]);
}
}
}
PairwiseSeamFinder::find(src, corners, masks);
}
void GraphCutSeamFinderGpu::findInPair(size_t first, size_t second, Rect roi)
{
Mat img1 = images_[first], img2 = images_[second];
Mat dx1 = dx_[first], dx2 = dx_[second];
Mat dy1 = dy_[first], dy2 = dy_[second];
Mat mask1 = masks_[first], mask2 = masks_[second];
Point tl1 = corners_[first], tl2 = corners_[second];
const int gap = 10;
Mat subimg1(roi.height + 2 * gap, roi.width + 2 * gap, CV_32FC3);
Mat subimg2(roi.height + 2 * gap, roi.width + 2 * gap, CV_32FC3);
Mat submask1(roi.height + 2 * gap, roi.width + 2 * gap, CV_8U);
Mat submask2(roi.height + 2 * gap, roi.width + 2 * gap, CV_8U);
Mat subdx1(roi.height + 2 * gap, roi.width + 2 * gap, CV_32F);
Mat subdy1(roi.height + 2 * gap, roi.width + 2 * gap, CV_32F);
Mat subdx2(roi.height + 2 * gap, roi.width + 2 * gap, CV_32F);
Mat subdy2(roi.height + 2 * gap, roi.width + 2 * gap, CV_32F);
// Cut subimages and submasks with some gap
for (int y = -gap; y < roi.height + gap; ++y)
{
for (int x = -gap; x < roi.width + gap; ++x)
{
int y1 = roi.y - tl1.y + y;
int x1 = roi.x - tl1.x + x;
if (y1 >= 0 && x1 >= 0 && y1 < img1.rows && x1 < img1.cols)
{
subimg1.at<Point3f>(y + gap, x + gap) = img1.at<Point3f>(y1, x1);
submask1.at<uchar>(y + gap, x + gap) = mask1.at<uchar>(y1, x1);
subdx1.at<float>(y + gap, x + gap) = dx1.at<float>(y1, x1);
subdy1.at<float>(y + gap, x + gap) = dy1.at<float>(y1, x1);
}
else
{
subimg1.at<Point3f>(y + gap, x + gap) = Point3f(0, 0, 0);
submask1.at<uchar>(y + gap, x + gap) = 0;
subdx1.at<float>(y + gap, x + gap) = 0.f;
subdy1.at<float>(y + gap, x + gap) = 0.f;
}
int y2 = roi.y - tl2.y + y;
int x2 = roi.x - tl2.x + x;
if (y2 >= 0 && x2 >= 0 && y2 < img2.rows && x2 < img2.cols)
{
subimg2.at<Point3f>(y + gap, x + gap) = img2.at<Point3f>(y2, x2);
submask2.at<uchar>(y + gap, x + gap) = mask2.at<uchar>(y2, x2);
subdx2.at<float>(y + gap, x + gap) = dx2.at<float>(y2, x2);
subdy2.at<float>(y + gap, x + gap) = dy2.at<float>(y2, x2);
}
else
{
subimg2.at<Point3f>(y + gap, x + gap) = Point3f(0, 0, 0);
submask2.at<uchar>(y + gap, x + gap) = 0;
subdx2.at<float>(y + gap, x + gap) = 0.f;
subdy2.at<float>(y + gap, x + gap) = 0.f;
}
}
}
Mat terminals, leftT, rightT, top, bottom;
switch (cost_type_)
{
case GraphCutSeamFinder::COST_COLOR:
setGraphWeightsColor(subimg1, subimg2, submask1, submask2,
terminals, leftT, rightT, top, bottom);
break;
case GraphCutSeamFinder::COST_COLOR_GRAD:
setGraphWeightsColorGrad(subimg1, subimg2, subdx1, subdx2, subdy1, subdy2,
submask1, submask2, terminals, leftT, rightT, top, bottom);
break;
default:
CV_Error(CV_StsBadArg, "unsupported pixel similarity measure");
}
gpu::GpuMat terminals_d(terminals);
gpu::GpuMat leftT_d(leftT);
gpu::GpuMat rightT_d(rightT);
gpu::GpuMat top_d(top);
gpu::GpuMat bottom_d(bottom);
gpu::GpuMat labels_d, buf_d;
gpu::graphcut(terminals_d, leftT_d, rightT_d, top_d, bottom_d, labels_d, buf_d);
Mat_<uchar> labels = labels_d;
for (int y = 0; y < roi.height; ++y)
{
for (int x = 0; x < roi.width; ++x)
{
if (labels(y + gap, x + gap))
{
if (mask1.at<uchar>(roi.y - tl1.y + y, roi.x - tl1.x + x))
mask2.at<uchar>(roi.y - tl2.y + y, roi.x - tl2.x + x) = 0;
}
else
{
if (mask2.at<uchar>(roi.y - tl2.y + y, roi.x - tl2.x + x))
mask1.at<uchar>(roi.y - tl1.y + y, roi.x - tl1.x + x) = 0;
}
}
}
}
void GraphCutSeamFinderGpu::setGraphWeightsColor(const Mat &img1, const Mat &img2, const Mat &mask1, const Mat &mask2,
Mat &terminals, Mat &leftT, Mat &rightT, Mat &top, Mat &bottom)
{
const Size img_size = img1.size();
terminals.create(img_size, CV_32S);
leftT.create(Size(img_size.height, img_size.width), CV_32S);
rightT.create(Size(img_size.height, img_size.width), CV_32S);
top.create(img_size, CV_32S);
bottom.create(img_size, CV_32S);
Mat_<int> terminals_(terminals);
Mat_<int> leftT_(leftT);
Mat_<int> rightT_(rightT);
Mat_<int> top_(top);
Mat_<int> bottom_(bottom);
// Set terminal weights
for (int y = 0; y < img_size.height; ++y)
{
for (int x = 0; x < img_size.width; ++x)
{
float source = mask1.at<uchar>(y, x) ? terminal_cost_ : 0.f;
float sink = mask2.at<uchar>(y, x) ? terminal_cost_ : 0.f;
terminals_(y, x) = saturate_cast<int>((source - sink) * 255.f);
}
}
// Set regular edge weights
const float weight_eps = 1.f;
for (int y = 0; y < img_size.height; ++y)
{
for (int x = 0; x < img_size.width; ++x)
{
if (x > 0)
{
float weight = normL2(img1.at<Point3f>(y, x - 1), img2.at<Point3f>(y, x - 1)) +
normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x)) +
weight_eps;
if (!mask1.at<uchar>(y, x - 1) || !mask1.at<uchar>(y, x) ||
!mask2.at<uchar>(y, x - 1) || !mask2.at<uchar>(y, x))
weight += bad_region_penalty_;
leftT_(x, y) = saturate_cast<int>(weight * 255.f);
}
else
leftT_(x, y) = 0;
if (x < img_size.width - 1)
{
float weight = normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x)) +
normL2(img1.at<Point3f>(y, x + 1), img2.at<Point3f>(y, x + 1)) +
weight_eps;
if (!mask1.at<uchar>(y, x) || !mask1.at<uchar>(y, x + 1) ||
!mask2.at<uchar>(y, x) || !mask2.at<uchar>(y, x + 1))
weight += bad_region_penalty_;
rightT_(x, y) = saturate_cast<int>(weight * 255.f);
}
else
rightT_(x, y) = 0;
if (y > 0)
{
float weight = normL2(img1.at<Point3f>(y - 1, x), img2.at<Point3f>(y - 1, x)) +
normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x)) +
weight_eps;
if (!mask1.at<uchar>(y - 1, x) || !mask1.at<uchar>(y, x) ||
!mask2.at<uchar>(y - 1, x) || !mask2.at<uchar>(y, x))
weight += bad_region_penalty_;
top_(y, x) = saturate_cast<int>(weight * 255.f);
}
else
top_(y, x) = 0;
if (y < img_size.height - 1)
{
float weight = normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x)) +
normL2(img1.at<Point3f>(y + 1, x), img2.at<Point3f>(y + 1, x)) +
weight_eps;
if (!mask1.at<uchar>(y, x) || !mask1.at<uchar>(y + 1, x) ||
!mask2.at<uchar>(y, x) || !mask2.at<uchar>(y + 1, x))
weight += bad_region_penalty_;
bottom_(y, x) = saturate_cast<int>(weight * 255.f);
}
else
bottom_(y, x) = 0;
}
}
}
void GraphCutSeamFinderGpu::setGraphWeightsColorGrad(
const Mat &img1, const Mat &img2, const Mat &dx1, const Mat &dx2,
const Mat &dy1, const Mat &dy2, const Mat &mask1, const Mat &mask2,
Mat &terminals, Mat &leftT, Mat &rightT, Mat &top, Mat &bottom)
{
const Size img_size = img1.size();
terminals.create(img_size, CV_32S);
leftT.create(Size(img_size.height, img_size.width), CV_32S);
rightT.create(Size(img_size.height, img_size.width), CV_32S);
top.create(img_size, CV_32S);
bottom.create(img_size, CV_32S);
Mat_<int> terminals_(terminals);
Mat_<int> leftT_(leftT);
Mat_<int> rightT_(rightT);
Mat_<int> top_(top);
Mat_<int> bottom_(bottom);
// Set terminal weights
for (int y = 0; y < img_size.height; ++y)
{
for (int x = 0; x < img_size.width; ++x)
{
float source = mask1.at<uchar>(y, x) ? terminal_cost_ : 0.f;
float sink = mask2.at<uchar>(y, x) ? terminal_cost_ : 0.f;
terminals_(y, x) = saturate_cast<int>((source - sink) * 255.f);
}
}
// Set regular edge weights
const float weight_eps = 1.f;
for (int y = 0; y < img_size.height; ++y)
{
for (int x = 0; x < img_size.width; ++x)
{
if (x > 0)
{
float grad = dx1.at<float>(y, x - 1) + dx1.at<float>(y, x) +
dx2.at<float>(y, x - 1) + dx2.at<float>(y, x) + weight_eps;
float weight = (normL2(img1.at<Point3f>(y, x - 1), img2.at<Point3f>(y, x - 1)) +
normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x))) / grad +
weight_eps;
if (!mask1.at<uchar>(y, x - 1) || !mask1.at<uchar>(y, x) ||
!mask2.at<uchar>(y, x - 1) || !mask2.at<uchar>(y, x))
weight += bad_region_penalty_;
leftT_(x, y) = saturate_cast<int>(weight * 255.f);
}
else
leftT_(x, y) = 0;
if (x < img_size.width - 1)
{
float grad = dx1.at<float>(y, x) + dx1.at<float>(y, x + 1) +
dx2.at<float>(y, x) + dx2.at<float>(y, x + 1) + weight_eps;
float weight = (normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x)) +
normL2(img1.at<Point3f>(y, x + 1), img2.at<Point3f>(y, x + 1))) / grad +
weight_eps;
if (!mask1.at<uchar>(y, x) || !mask1.at<uchar>(y, x + 1) ||
!mask2.at<uchar>(y, x) || !mask2.at<uchar>(y, x + 1))
weight += bad_region_penalty_;
rightT_(x, y) = saturate_cast<int>(weight * 255.f);
}
else
rightT_(x, y) = 0;
if (y > 0)
{
float grad = dy1.at<float>(y - 1, x) + dy1.at<float>(y, x) +
dy2.at<float>(y - 1, x) + dy2.at<float>(y, x) + weight_eps;
float weight = (normL2(img1.at<Point3f>(y - 1, x), img2.at<Point3f>(y - 1, x)) +
normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x))) / grad +
weight_eps;
if (!mask1.at<uchar>(y - 1, x) || !mask1.at<uchar>(y, x) ||
!mask2.at<uchar>(y - 1, x) || !mask2.at<uchar>(y, x))
weight += bad_region_penalty_;
top_(y, x) = saturate_cast<int>(weight * 255.f);
}
else
top_(y, x) = 0;
if (y < img_size.height - 1)
{
float grad = dy1.at<float>(y, x) + dy1.at<float>(y + 1, x) +
dy2.at<float>(y, x) + dy2.at<float>(y + 1, x) + weight_eps;
float weight = (normL2(img1.at<Point3f>(y, x), img2.at<Point3f>(y, x)) +
normL2(img1.at<Point3f>(y + 1, x), img2.at<Point3f>(y + 1, x))) / grad +
weight_eps;
if (!mask1.at<uchar>(y, x) || !mask1.at<uchar>(y + 1, x) ||
!mask2.at<uchar>(y, x) || !mask2.at<uchar>(y + 1, x))
weight += bad_region_penalty_;
bottom_(y, x) = saturate_cast<int>(weight * 255.f);
}
else
bottom_(y, x) = 0;
}
}
}
#endif
} // namespace detail
} // namespace cv

View File

@ -63,16 +63,17 @@ Stitcher Stitcher::createDefault(bool try_use_gpu)
{
stitcher.setFeaturesFinder(new detail::SurfFeaturesFinderGpu());
stitcher.setWarper(new SphericalWarperGpu());
stitcher.setSeamFinder(new detail::GraphCutSeamFinderGpu());
}
else
#endif
{
stitcher.setFeaturesFinder(new detail::SurfFeaturesFinder());
stitcher.setWarper(new SphericalWarper());
stitcher.setSeamFinder(new detail::GraphCutSeamFinder());
}
stitcher.setExposureCompenstor(new detail::BlocksGainCompensator());
stitcher.setSeamFinder(new detail::GraphCutSeamFinder());
stitcher.setBlender(new detail::MultiBandBlender(try_use_gpu));
return stitcher;

View File

@ -132,7 +132,7 @@ std::string save_graph_to;
string warp_type = "spherical";
int expos_comp_type = ExposureCompensator::GAIN_BLOCKS;
float match_conf = 0.65f;
int seam_find_type = SeamFinder::GC_COLOR;
string seam_find_type = "gc_color";
int blend_type = Blender::MULTI_BAND;
float blend_strength = 5;
string result_name = "result.jpg";
@ -262,14 +262,11 @@ int parseCmdArgs(int argc, char** argv)
}
else if (string(argv[i]) == "--seam")
{
if (string(argv[i + 1]) == "no")
seam_find_type = SeamFinder::NO;
else if (string(argv[i + 1]) == "voronoi")
seam_find_type = SeamFinder::VORONOI;
else if (string(argv[i + 1]) == "gc_color")
seam_find_type = SeamFinder::GC_COLOR;
else if (string(argv[i + 1]) == "gc_colorgrad")
seam_find_type = SeamFinder::GC_COLOR_GRAD;
if (string(argv[i + 1]) == "no" ||
string(argv[i + 1]) == "voronoi" ||
string(argv[i + 1]) == "gc_color" ||
string(argv[i + 1]) == "gc_colorgrad")
seam_find_type = argv[i + 1];
else
{
cout << "Bad seam finding method\n";
@ -550,7 +547,35 @@ int main(int argc, char* argv[])
Ptr<ExposureCompensator> compensator = ExposureCompensator::createDefault(expos_comp_type);
compensator->feed(corners, images_warped, masks_warped);
Ptr<SeamFinder> seam_finder = SeamFinder::createDefault(seam_find_type);
Ptr<SeamFinder> seam_finder;
if (seam_find_type == "no")
seam_finder = new detail::NoSeamFinder();
else if (seam_find_type == "voronoi")
seam_finder = new detail::VoronoiSeamFinder();
else if (seam_find_type == "gc_color")
{
#ifndef ANDROID
if (try_gpu)
seam_finder = new detail::GraphCutSeamFinderGpu(GraphCutSeamFinderBase::COST_COLOR);
else
#endif
seam_finder = new detail::GraphCutSeamFinder(GraphCutSeamFinderBase::COST_COLOR);
}
else if (seam_find_type == "gc_colorgrad")
{
#ifndef ANDROID
if (try_gpu)
seam_finder = new detail::GraphCutSeamFinderGpu(GraphCutSeamFinderBase::COST_COLOR_GRAD);
else
#endif
seam_finder = new detail::GraphCutSeamFinder(GraphCutSeamFinderBase::COST_COLOR_GRAD);
}
if (seam_finder.empty())
{
cout << "Can't create the following seam finder '" << seam_find_type << "'\n";
return 1;
}
seam_finder->find(images_warped_f, corners, masks_warped);
// Release unused memory