Added margin type, added tests with different scales of features.
Also fixed documentation, refactored sample.
This commit is contained in:
@@ -40,16 +40,13 @@ void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int
|
||||
void redraw(Data data, const Point points[2]);
|
||||
|
||||
//add point in train set, train SVMSGD algorithm and draw results on image
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y);
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
|
||||
|
||||
|
||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00000001));
|
||||
svmsgd->setLambda(0.00000001);
|
||||
|
||||
svmsgd->setOptimalParameters();
|
||||
|
||||
cv::Ptr<TrainData> train_data = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||
svmsgd->train( train_data );
|
||||
@@ -64,6 +61,27 @@ bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift
|
||||
return false;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||
{
|
||||
std::pair<Point,Point> currentSegment;
|
||||
|
||||
currentSegment.first = Point(width, 0);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, height);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(width, 0);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(0, height);
|
||||
segments.push_back(currentSegment);
|
||||
}
|
||||
|
||||
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||
{
|
||||
@@ -123,27 +141,6 @@ bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int
|
||||
return true;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||
{
|
||||
std::pair<Point,Point> currentSegment;
|
||||
|
||||
currentSegment.first = Point(width, 0);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, height);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(width, 0);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(0, height);
|
||||
segments.push_back(currentSegment);
|
||||
}
|
||||
|
||||
void redraw(Data data, const Point points[2])
|
||||
{
|
||||
data.img.setTo(0);
|
||||
@@ -162,19 +159,20 @@ void redraw(Data data, const Point points[2])
|
||||
imshow("Train svmsgd", data.img);
|
||||
}
|
||||
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y)
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
|
||||
{
|
||||
Mat currentSample(1, 2, CV_32F);
|
||||
|
||||
currentSample.at<float>(0,0) = x;
|
||||
currentSample.at<float>(0,1) = y;
|
||||
data.samples.push_back(currentSample);
|
||||
data.responses.push_back(response);
|
||||
|
||||
Mat weights(1, 2, CV_32F);
|
||||
float shift = 0;
|
||||
|
||||
if (doTrain(data.samples, data.responses, weights, shift))
|
||||
{
|
||||
{
|
||||
Point points[2];
|
||||
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
|
||||
|
||||
@@ -189,15 +187,12 @@ static void onMouse( int event, int x, int y, int, void* pData)
|
||||
|
||||
switch( event )
|
||||
{
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
data.responses.push_back(1);
|
||||
addPointRetrainAndRedraw(data, x, y);
|
||||
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
addPointRetrainAndRedraw(data, x, y, 1);
|
||||
break;
|
||||
|
||||
case CV_EVENT_RBUTTONDOWN:
|
||||
data.responses.push_back(-1);
|
||||
addPointRetrainAndRedraw(data, x, y);
|
||||
addPointRetrainAndRedraw(data, x, y, -1);
|
||||
break;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user