From 581d4545366b4f9392352af8737cba73b8a9dc5a Mon Sep 17 00:00:00 2001
From: Alex Leontiev
Date: Sun, 22 Sep 2013 00:14:49 +0800
Subject: [PATCH] Refined interface for Conjugate Gradient

Parts of the interface were refined: most notably, the method for
returning the Hessian was removed, and an optional method for getting
the gradient was added to the base Solver::Function class. Basic code
for the setters/getters was added as well. Now is the time for the real
work on the algorithm.

---
 modules/optim/doc/conjugate_gradient.rst           | 11 +++
 modules/optim/include/opencv2/optim.hpp            | 13 +---
 modules/optim/src/conjugate_gradient.cpp           | 77 +++++++++++++++++++
 modules/optim/src/simplex.cpp                      |  7 +-
 .../optim/test/test_conjugate_gradient.cpp         | 61 +++++++++++++++
 5 files changed, 157 insertions(+), 12 deletions(-)
 create mode 100644 modules/optim/doc/conjugate_gradient.rst
 create mode 100644 modules/optim/src/conjugate_gradient.cpp
 create mode 100644 modules/optim/test/test_conjugate_gradient.cpp

diff --git a/modules/optim/doc/conjugate_gradient.rst b/modules/optim/doc/conjugate_gradient.rst
new file mode 100644
index 000000000..cd9697465
--- /dev/null
+++ b/modules/optim/doc/conjugate_gradient.rst
@@ -0,0 +1,11 @@
+Conjugate Gradient
+=======================
+
+.. highlight:: cpp
+
+optim::ConjGradSolver
+---------------------------------
+
+.. ocv:class:: optim::ConjGradSolver
+
+This class is used to perform the non-linear non-constrained minimization of a function with known gradient.

diff --git a/modules/optim/include/opencv2/optim.hpp b/modules/optim/include/opencv2/optim.hpp
index c1e7819b6..0a460cce0 100644
--- a/modules/optim/include/opencv2/optim.hpp
+++ b/modules/optim/include/opencv2/optim.hpp
@@ -55,6 +55,7 @@ public:
     public:
         virtual ~Function() {}
         virtual double calc(const double* x) const = 0;
+        virtual void getGradient(const double* /*x*/,double* /*grad*/) {}
     };

     virtual Ptr<Function> getFunction() const = 0;
@@ -86,17 +87,7 @@
 CV_EXPORTS_W Ptr<DownhillSolver> createDownhillSolver(const Ptr<Solver::Function>& f=Ptr<Solver::Function>(),
         InputArray initStep=Mat_<double>(1,1,0.0),
         TermCriteria termcrit=TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS,5000,0.000001));

+CV_EXPORTS_W Ptr<ConjGradSolver> createConjGradSolver(const Ptr<Solver::Function>& f=Ptr<Solver::Function>(),
+        TermCriteria termcrit=TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS,5000,0.000001));

diff --git a/modules/optim/src/conjugate_gradient.cpp b/modules/optim/src/conjugate_gradient.cpp
new file mode 100644
index 000000000..7e555c6cc
--- /dev/null
+++ b/modules/optim/src/conjugate_gradient.cpp
@@ -0,0 +1,77 @@
+#include "precomp.hpp"
+#include "debug.hpp"
+
+namespace cv{namespace optim{
+
+    class ConjGradSolverImpl : public ConjGradSolver
+    {
+    public:
+        Ptr<Function> getFunction() const;
+        void setFunction(const Ptr<Function>& f);
+        TermCriteria getTermCriteria() const;
+        ConjGradSolverImpl();
+        void setTermCriteria(const TermCriteria& termcrit);
+        double minimize(InputOutputArray x);
+    protected:
+        Ptr<Function> _Function;
+        TermCriteria _termcrit;
+        Mat_<double> d,r,buf_x,r_old;
+    private:
+    };
+
+    double ConjGradSolverImpl::minimize(InputOutputArray x){
+        CV_Assert(_Function.empty()==false);
+        dprintf(("termcrit:\n\ttype: %d\n\tmaxCount: %d\n\tEPS: %g\n",_termcrit.type,_termcrit.maxCount,_termcrit.epsilon));
+
+        Mat x_mat=x.getMat();
+        CV_Assert(MIN(x_mat.rows,x_mat.cols)==1);
+        int ndim=MAX(x_mat.rows,x_mat.cols);
+        CV_Assert(x_mat.type()==CV_64FC1);
+
+        d.create(1,ndim);
+        r.create(1,ndim);
+        r_old.create(1,ndim);
+
+        Mat_<double> proxy_x;
+        if(x_mat.rows>1){
+            buf_x.create(1,ndim);
+            Mat_<double> proxy(ndim,1,(double*)buf_x.data);
+            x_mat.copyTo(proxy);
+            proxy_x=buf_x;
+        }else{
+            proxy_x=x_mat;
+        }
+
+        //the actual algorithm goes here; check first that everything is set up properly
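+        //A sketch of the nonlinear CG iteration intended for this spot (an
+        //assumption about the upcoming work, not something this patch implements):
+        //  r = -grad f(x); d = r;
+        //then, until _termcrit is satisfied, line-search alpha along d and update
+        //  x += alpha*d;  r_old = r;  r = -grad f(x);
+        //  beta = max(0, r.dot(r - r_old)/r_old.dot(r_old));  //Polak-Ribiere
+        //  d = r + beta*d;
+        //d, r and r_old are the row-vector buffers allocated above.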
+
+        if(x_mat.rows>1){
+            Mat(ndim, 1, CV_64F, (double*)proxy_x.data).copyTo(x);
+        }
+        return 0.0;
+    }
+
+    ConjGradSolverImpl::ConjGradSolverImpl(){
+        _Function=Ptr<Function>();
+    }
+    Ptr<Solver::Function> ConjGradSolverImpl::getFunction()const{
+        return _Function;
+    }
+    void ConjGradSolverImpl::setFunction(const Ptr<Function>& f){
+        _Function=f;
+    }
+    TermCriteria ConjGradSolverImpl::getTermCriteria()const{
+        return _termcrit;
+    }
+    void ConjGradSolverImpl::setTermCriteria(const TermCriteria& termcrit){
+        CV_Assert((termcrit.type==(TermCriteria::MAX_ITER+TermCriteria::EPS) && termcrit.epsilon>0 && termcrit.maxCount>0) ||
+                ((termcrit.type==TermCriteria::MAX_ITER) && termcrit.maxCount>0));
+        _termcrit=termcrit;
+    }
+    // both minRange & minError are specified by termcrit.epsilon; in addition, the user may specify the number of iterations for the algorithm
+    Ptr<ConjGradSolver> createConjGradSolver(const Ptr<Solver::Function>& f, TermCriteria termcrit){
+        ConjGradSolver *CG=new ConjGradSolverImpl();
+        CG->setFunction(f);
+        CG->setTermCriteria(termcrit);
+        return Ptr<ConjGradSolver>(CG);
+    }
+}}

diff --git a/modules/optim/src/simplex.cpp b/modules/optim/src/simplex.cpp
index f45d0ce0b..54de6ed8c 100644
--- a/modules/optim/src/simplex.cpp
+++ b/modules/optim/src/simplex.cpp
@@ -19,6 +19,8 @@ namespace cv{namespace optim{
         Ptr<Function> _Function;
         TermCriteria _termcrit;
         Mat _step;
+        Mat_<double> buf_x;
+
     private:
         inline void createInitialSimplex(Mat_<double>& simplex,Mat& step);
         inline double innerDownhillSimplex(cv::Mat_<double>& p,double MinRange,double MinError,int& nfunk,
@@ -209,7 +211,10 @@ namespace cv{namespace optim{

         Mat_<double> proxy_x;
         if(x_mat.rows>1){
-            proxy_x=x_mat.t();
+            buf_x.create(1,_step.cols);
+            Mat_<double> proxy(_step.cols,1,(double*)buf_x.data);
+            x_mat.copyTo(proxy);
+            proxy_x=buf_x;
         }else{
             proxy_x=x_mat;
         }

diff --git a/modules/optim/test/test_conjugate_gradient.cpp b/modules/optim/test/test_conjugate_gradient.cpp
new file mode 100644
index 000000000..2456a5cab
--- /dev/null
+++ b/modules/optim/test/test_conjugate_gradient.cpp
@@ -0,0 +1,61 @@
+#include "test_precomp.hpp"
+#include <iostream>
+
+static void mytest(cv::Ptr<cv::optim::ConjGradSolver> solver,cv::Ptr<cv::optim::Solver::Function> ptr_F,cv::Mat& x,
+        cv::Mat& etalon_x,double etalon_res){
+    solver->setFunction(ptr_F);
+    //int ndim=MAX(step.cols,step.rows);
+    double res=solver->minimize(x);
+    std::cout<<"res:\n\t"<<res<<std::endl;
+    double tol=solver->getTermCriteria().epsilon;
+    ASSERT_TRUE(std::abs(res-etalon_res)<tol);
+    for(cv::Mat_<double>::iterator it1=x.begin<double>(),it2=etalon_x.begin<double>();it1!=x.end<double>();it1++,it2++){
+        ASSERT_TRUE(std::abs((*it1)-(*it2))<tol);
+    }
+}
+
+class SphereF:public cv::optim::Solver::Function{
+public:
+    double calc(const double* x)const{
+        return x[0]*x[0]+x[1]*x[1];
+    }
+    void getGradient(const double* x,double* grad){
+        for(int i=0;i<2;i++)
+            grad[i]=2*x[i];
+    }
+};
+class RosenbrockF:public cv::optim::Solver::Function{
+    double calc(const double* x)const{
+        return 100*(x[1]-x[0]*x[0])*(x[1]-x[0]*x[0])+(1-x[0])*(1-x[0]);
+    }
+};
+
+TEST(Optim_ConjGrad, regression_basic){
+    cv::Ptr<cv::optim::ConjGradSolver> solver=cv::optim::createConjGradSolver();
+#if 1
+    {
+        cv::Ptr<cv::optim::Solver::Function> ptr_F(new SphereF());
+        cv::Mat x=(cv::Mat_<double>(1,2)<<1.0,1.0),
+            etalon_x=(cv::Mat_<double>(1,2)<<0.0,0.0);
+        double etalon_res=0.0;
+        return; //minimize() is still a stub, so skip the actual checks for now
+        mytest(solver,ptr_F,x,etalon_x,etalon_res);
+    }
+#endif
+#if 0
+    {
+        cv::Ptr<cv::optim::Solver::Function> ptr_F(new RosenbrockF());
+        cv::Mat x=(cv::Mat_<double>(2,1)<<0.0,0.0),
+            etalon_x=(cv::Mat_<double>(2,1)<<1.0,1.0);
+        double etalon_res=0.0;
+        mytest(solver,ptr_F,x,etalon_x,etalon_res);
+    }
+#endif
+}
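
Reviewer note: below is a minimal usage sketch of the refined interface, showing how a caller is expected to plug a gradient-aware function into the solver once minimize() does real work. It is illustrative only and not part of the patch; SphereF mirrors the test fixture above, and the TermCriteria values are arbitrary.

    #include <iostream>
    #include "opencv2/optim.hpp"

    //f(x,y)=x^2+y^2, with minimum 0.0 attained at (0,0)
    class SphereF : public cv::optim::Solver::Function{
    public:
        double calc(const double* x) const{
            return x[0]*x[0]+x[1]*x[1];
        }
        //the optional override introduced by this patch; conjugate gradient
        //cannot run without a gradient, so any real Function must provide it
        void getGradient(const double* x,double* grad){
            grad[0]=2*x[0];
            grad[1]=2*x[1];
        }
    };

    int main(){
        cv::Ptr<cv::optim::ConjGradSolver> solver=cv::optim::createConjGradSolver();
        solver->setFunction(cv::Ptr<cv::optim::Solver::Function>(new SphereF()));
        //stop after 100 iterations, or earlier at epsilon-level accuracy
        solver->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER+cv::TermCriteria::EPS,100,1e-6));
        //row-vector starting point; minimize() overwrites it with the argmin
        cv::Mat x=(cv::Mat_<double>(1,2)<<1.0,1.0);
        double res=solver->minimize(x);
        std::cout<<"minimum: "<<res<<"\nargmin: "<<x<<std::endl;
        return 0;
    }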