Increasing the dimension of features space in the SVMSGD::train function.
This commit is contained in:
@@ -12,10 +12,8 @@ using namespace cv::ml;
|
||||
struct Data
|
||||
{
|
||||
Mat img;
|
||||
Mat samples;
|
||||
Mat responses;
|
||||
RNG rng;
|
||||
//Point points[2];
|
||||
Mat samples; //Set of train samples. Contains points on image
|
||||
Mat responses; //Set of responses for train samples
|
||||
|
||||
Data()
|
||||
{
|
||||
@@ -24,24 +22,36 @@ struct Data
|
||||
}
|
||||
};
|
||||
|
||||
bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift);
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]);
|
||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments);
|
||||
//Train with SVMSGD algorithm
|
||||
//(samples, responses) is a train set
|
||||
//weights is a required vector for decision function of SVMSGD algorithm
|
||||
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);
|
||||
|
||||
//function finds two points for drawing line (wx = 0)
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int width, int height);
|
||||
|
||||
// function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
|
||||
|
||||
//segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);
|
||||
|
||||
//redraw points' set and line (wx = 0)
|
||||
void redraw(Data data, const Point points[2]);
|
||||
void addPointsRetrainAndRedraw(Data &data, int x, int y);
|
||||
|
||||
//add point in train set, train SVMSGD algorithm and draw results on image
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y);
|
||||
|
||||
|
||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001));
|
||||
svmsgd->setLambda(0.01);
|
||||
svmsgd->setGamma0(1);
|
||||
// svmsgd->setC(5);
|
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00000001));
|
||||
svmsgd->setLambda(0.00000001);
|
||||
|
||||
cv::Ptr<TrainData> train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
||||
|
||||
cv::Ptr<TrainData> train_data = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||
svmsgd->train( train_data );
|
||||
|
||||
if (svmsgd->isTrained())
|
||||
@@ -49,36 +59,39 @@ bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift
|
||||
weights = svmsgd->getWeights();
|
||||
shift = svmsgd->getShift();
|
||||
|
||||
std::cout << weights << std::endl;
|
||||
std::cout << shift << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||
{
|
||||
int x = 0;
|
||||
int y = 0;
|
||||
//с (0,0) всё плохо
|
||||
if (segment.first.x == segment.second.x && weights.at<float>(1) != 0)
|
||||
int xMin = std::min(segment.first.x, segment.second.x);
|
||||
int xMax = std::max(segment.first.x, segment.second.x);
|
||||
int yMin = std::min(segment.first.y, segment.second.y);
|
||||
int yMax = std::max(segment.first.y, segment.second.y);
|
||||
|
||||
CV_Assert(xMin == xMax || yMin == yMax);
|
||||
|
||||
if (xMin == xMax && weights.at<float>(1) != 0)
|
||||
{
|
||||
x = segment.first.x;
|
||||
y = -(weights.at<float>(0) * x + shift) / weights.at<float>(1);
|
||||
if (y >= 0 && y <= HEIGHT)
|
||||
x = xMin;
|
||||
y = std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1));
|
||||
if (y >= yMin && y <= yMax)
|
||||
{
|
||||
crossPoint.x = x;
|
||||
crossPoint.y = y;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (segment.first.y == segment.second.y && weights.at<float>(0) != 0)
|
||||
else if (yMin == yMax && weights.at<float>(0) != 0)
|
||||
{
|
||||
y = segment.first.y;
|
||||
x = - (weights.at<float>(1) * y + shift) / weights.at<float>(0);
|
||||
if (x >= 0 && x <= WIDTH)
|
||||
y = yMin;
|
||||
x = std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0));
|
||||
if (x >= xMin && x <= xMax)
|
||||
{
|
||||
crossPoint.x = x;
|
||||
crossPoint.y = y;
|
||||
@@ -88,7 +101,7 @@ bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point
|
||||
return false;
|
||||
}
|
||||
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int width, int height)
|
||||
{
|
||||
if (weights.empty())
|
||||
{
|
||||
@@ -97,42 +110,43 @@ bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
|
||||
|
||||
int foundPointsCount = 0;
|
||||
std::vector<std::pair<Point,Point> > segments;
|
||||
fillSegments(segments);
|
||||
fillSegments(segments, width, height);
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
for (uint i = 0; i < segments.size(); i++)
|
||||
{
|
||||
if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount]))
|
||||
if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
|
||||
foundPointsCount++;
|
||||
if (foundPointsCount > 2)
|
||||
if (foundPointsCount >= 2)
|
||||
break;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments)
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||
{
|
||||
std::pair<Point,Point> curSegment;
|
||||
std::pair<Point,Point> currentSegment;
|
||||
|
||||
curSegment.first = Point(0,0);
|
||||
curSegment.second = Point(0,HEIGHT);
|
||||
segments.push_back(curSegment);
|
||||
currentSegment.first = Point(width, 0);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
curSegment.first = Point(0,0);
|
||||
curSegment.second = Point(WIDTH,0);
|
||||
segments.push_back(curSegment);
|
||||
currentSegment.first = Point(0, height);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
curSegment.first = Point(WIDTH,0);
|
||||
curSegment.second = Point(WIDTH,HEIGHT);
|
||||
segments.push_back(curSegment);
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(width, 0);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
curSegment.first = Point(0,HEIGHT);
|
||||
curSegment.second = Point(WIDTH,HEIGHT);
|
||||
segments.push_back(curSegment);
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(0, height);
|
||||
segments.push_back(currentSegment);
|
||||
}
|
||||
|
||||
void redraw(Data data, const Point points[2])
|
||||
{
|
||||
data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
|
||||
data.img.setTo(0);
|
||||
Point center;
|
||||
int radius = 3;
|
||||
Scalar color;
|
||||
@@ -143,48 +157,26 @@ void redraw(Data data, const Point points[2])
|
||||
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
|
||||
circle(data.img, center, radius, color, 5);
|
||||
}
|
||||
line(data.img, points[0],points[1],cv::Scalar(1,255,1));
|
||||
line(data.img, points[0], points[1],cv::Scalar(1,255,1));
|
||||
|
||||
imshow("Train svmsgd", data.img);
|
||||
}
|
||||
|
||||
void addPointsRetrainAndRedraw(Data &data, int x, int y)
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y)
|
||||
{
|
||||
|
||||
Mat currentSample(1, 2, CV_32F);
|
||||
//start
|
||||
/*
|
||||
Mat _weights;
|
||||
_weights.create(1, 2, CV_32FC1);
|
||||
_weights.at<float>(0) = 1;
|
||||
_weights.at<float>(1) = -1;
|
||||
|
||||
int _x, _y;
|
||||
|
||||
for (int i=0;i<199;i++)
|
||||
{
|
||||
_x = data.rng.uniform(0,800);
|
||||
_y = data.rng.uniform(0,500);*/
|
||||
currentSample.at<float>(0,0) = x;
|
||||
currentSample.at<float>(0,1) = y;
|
||||
//if (currentSample.dot(_weights) > 0)
|
||||
//data.responses.push_back(1);
|
||||
// else data.responses.push_back(-1);
|
||||
|
||||
//finish
|
||||
data.samples.push_back(currentSample);
|
||||
|
||||
|
||||
|
||||
Mat weights(1, 2, CV_32F);
|
||||
float shift = 0;
|
||||
|
||||
if (doTrain(data.samples, data.responses, weights, shift))
|
||||
{
|
||||
{
|
||||
Point points[2];
|
||||
shift = 0;
|
||||
|
||||
findPointsForLine(weights, shift, points);
|
||||
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
|
||||
|
||||
redraw(data, points);
|
||||
}
|
||||
@@ -199,13 +191,13 @@ static void onMouse( int event, int x, int y, int, void* pData)
|
||||
{
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
data.responses.push_back(1);
|
||||
addPointsRetrainAndRedraw(data, x, y);
|
||||
addPointRetrainAndRedraw(data, x, y);
|
||||
|
||||
break;
|
||||
|
||||
case CV_EVENT_RBUTTONDOWN:
|
||||
data.responses.push_back(-1);
|
||||
addPointsRetrainAndRedraw(data, x, y);
|
||||
addPointRetrainAndRedraw(data, x, y);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -213,14 +205,10 @@ static void onMouse( int event, int x, int y, int, void* pData)
|
||||
|
||||
int main()
|
||||
{
|
||||
|
||||
Data data;
|
||||
|
||||
setMouseCallback( "Train svmsgd", onMouse, &data );
|
||||
waitKey();
|
||||
|
||||
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user