Added margin type, added tests with different scales of features.

Also fixed documentation, refactored sample.
This commit is contained in:
Marina Noskova
2016-02-09 18:42:23 +03:00
parent acd74037b3
commit bfdca05f25
6 changed files with 429 additions and 221 deletions

View File

@@ -40,16 +40,13 @@ void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int
void redraw(Data data, const Point points[2]);
//add point in train set, train SVMSGD algorithm and draw results on image
void addPointRetrainAndRedraw(Data &data, int x, int y);
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
{
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
svmsgd->setOptimalParameters(SVMSGD::ASGD);
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00000001));
svmsgd->setLambda(0.00000001);
svmsgd->setOptimalParameters();
cv::Ptr<TrainData> train_data = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
svmsgd->train( train_data );
@@ -64,6 +61,27 @@ bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift
return false;
}
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{
std::pair<Point,Point> currentSegment;
currentSegment.first = Point(width, 0);
currentSegment.second = Point(width, height);
segments.push_back(currentSegment);
currentSegment.first = Point(0, height);
currentSegment.second = Point(width, height);
segments.push_back(currentSegment);
currentSegment.first = Point(0, 0);
currentSegment.second = Point(width, 0);
segments.push_back(currentSegment);
currentSegment.first = Point(0, 0);
currentSegment.second = Point(0, height);
segments.push_back(currentSegment);
}
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{
@@ -123,27 +141,6 @@ bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int
return true;
}
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{
std::pair<Point,Point> currentSegment;
currentSegment.first = Point(width, 0);
currentSegment.second = Point(width, height);
segments.push_back(currentSegment);
currentSegment.first = Point(0, height);
currentSegment.second = Point(width, height);
segments.push_back(currentSegment);
currentSegment.first = Point(0, 0);
currentSegment.second = Point(width, 0);
segments.push_back(currentSegment);
currentSegment.first = Point(0, 0);
currentSegment.second = Point(0, height);
segments.push_back(currentSegment);
}
void redraw(Data data, const Point points[2])
{
data.img.setTo(0);
@@ -162,19 +159,20 @@ void redraw(Data data, const Point points[2])
imshow("Train svmsgd", data.img);
}
void addPointRetrainAndRedraw(Data &data, int x, int y)
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
Mat currentSample(1, 2, CV_32F);
currentSample.at<float>(0,0) = x;
currentSample.at<float>(0,1) = y;
data.samples.push_back(currentSample);
data.responses.push_back(response);
Mat weights(1, 2, CV_32F);
float shift = 0;
if (doTrain(data.samples, data.responses, weights, shift))
{
{
Point points[2];
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
@@ -189,15 +187,12 @@ static void onMouse( int event, int x, int y, int, void* pData)
switch( event )
{
case CV_EVENT_LBUTTONUP:
data.responses.push_back(1);
addPointRetrainAndRedraw(data, x, y);
case CV_EVENT_LBUTTONUP:
addPointRetrainAndRedraw(data, x, y, 1);
break;
case CV_EVENT_RBUTTONDOWN:
data.responses.push_back(-1);
addPointRetrainAndRedraw(data, x, y);
addPointRetrainAndRedraw(data, x, y, -1);
break;
}