Add OpenCL SVM paths for bagofwords_classification and points_classifier samples.

This commit is contained in:
Baichuan Su
2013-11-09 08:42:39 +08:00
committed by Ilya Lavrenov
parent d8a4d3a2eb
commit 00300baa53
3 changed files with 54 additions and 1 deletions

View File

@@ -1,6 +1,10 @@
#include "opencv2/opencv_modules.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/ml/ml.hpp"
#include "opencv2/highgui/highgui.hpp"
#ifdef HAVE_OPENCV_OCL
#include "opencv2/ocl/ocl.hpp"
#endif
#include <stdio.h>
@@ -133,7 +137,14 @@ static void find_decision_boundary_KNN( int K )
prepare_train_data( trainSamples, trainClasses );
// learn classifier
#ifdef HAVE_OPENCV_OCL
cv::ocl::KNearestNeighbour knnClassifier;
Mat temp, result;
knnClassifier.train(trainSamples, trainClasses, temp, false, K);
cv::ocl::oclMat testSample_ocl, reslut_ocl;
#else
CvKNearest knnClassifier( trainSamples, trainClasses, Mat(), false, K );
#endif
Mat testSample( 1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
@@ -142,9 +153,19 @@ static void find_decision_boundary_KNN( int K )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
#ifdef HAVE_OPENCV_OCL
testSample_ocl.upload(testSample);
knnClassifier.find_nearest(testSample_ocl, K, reslut_ocl);
reslut_ocl.download(result);
int response = saturate_cast<int>(result.at<float>(0));
circle(imgDst, Point(x, y), 1, classColors[response]);
#else
int response = (int)knnClassifier.find_nearest( testSample, K );
circle( imgDst, Point(x,y), 1, classColors[response] );
#endif
}
}
}
@@ -159,7 +180,11 @@ static void find_decision_boundary_SVM( CvSVMParams params )
prepare_train_data( trainSamples, trainClasses );
// learn classifier
#ifdef HAVE_OPENCV_OCL
cv::ocl::CvSVM_OCL svmClassifier(trainSamples, trainClasses, Mat(), Mat(), params);
#else
CvSVM svmClassifier( trainSamples, trainClasses, Mat(), Mat(), params );
#endif
Mat testSample( 1, 2, CV_32FC1 );
for( int y = 0; y < img.rows; y += testStep )
@@ -178,7 +203,7 @@ static void find_decision_boundary_SVM( CvSVMParams params )
for( int i = 0; i < svmClassifier.get_support_vector_count(); i++ )
{
const float* supportVector = svmClassifier.get_support_vector(i);
circle( imgDst, Point(supportVector[0],supportVector[1]), 5, Scalar(255,255,255), -1 );
circle( imgDst, Point(saturate_cast<int>(supportVector[0]),saturate_cast<int>(supportVector[1])), 5, CV_RGB(255,255,255), -1 );
}
}