replaced Calonder descriptor implementation; added windowedMatchingMask()

This commit is contained in:
Maria Dimashova 2010-07-27 12:36:48 +00:00
parent 4f3de6ebf6
commit e83c9b08d8
3 changed files with 1306 additions and 288 deletions

View File

@ -605,6 +605,207 @@ protected:
/****************************************************************************************\
* Calonder Classifier *
\****************************************************************************************/
struct RTreeNode;
struct CV_EXPORTS BaseKeypoint
{
int x;
int y;
IplImage* image;
BaseKeypoint()
: x(0), y(0), image(NULL)
{}
BaseKeypoint(int x, int y, IplImage* image)
: x(x), y(y), image(image)
{}
};
class CV_EXPORTS RandomizedTree
{
public:
friend class RTreeClassifier;
static const int PATCH_SIZE = 32;
static const int DEFAULT_DEPTH = 9;
static const int DEFAULT_VIEWS = 5000;
static const size_t DEFAULT_REDUCED_NUM_DIM = 176;
static const float LOWER_QUANT_PERC = .03f;
static const float UPPER_QUANT_PERC = .92f;
RandomizedTree();
~RandomizedTree();
void train(std::vector<BaseKeypoint> const& base_set, RNG &rng,
int depth, int views, size_t reduced_num_dim, int num_quant_bits);
void train(std::vector<BaseKeypoint> const& base_set, RNG &rng,
PatchGenerator &make_patch, int depth, int views, size_t reduced_num_dim,
int num_quant_bits);
// following two funcs are EXPERIMENTAL (do not use unless you know exactly what you do)
static void quantizeVector(float *vec, int dim, int N, float bnds[2], int clamp_mode=0);
static void quantizeVector(float *src, int dim, int N, float bnds[2], uint8_t *dst);
// patch_data must be a 32x32 array (no row padding)
float* getPosterior(uchar* patch_data);
const float* getPosterior(uchar* patch_data) const;
uint8_t* getPosterior2(uchar* patch_data);
const uint8_t* getPosterior2(uchar* patch_data) const;
void read(const char* file_name, int num_quant_bits);
void read(std::istream &is, int num_quant_bits);
void write(const char* file_name) const;
void write(std::ostream &os) const;
int classes() { return classes_; }
int depth() { return depth_; }
//void setKeepFloatPosteriors(bool b) { keep_float_posteriors_ = b; }
void discardFloatPosteriors() { freePosteriors(1); }
inline void applyQuantization(int num_quant_bits) { makePosteriors2(num_quant_bits); }
// debug
void savePosteriors(std::string url, bool append=false);
void savePosteriors2(std::string url, bool append=false);
private:
int classes_;
int depth_;
int num_leaves_;
std::vector<RTreeNode> nodes_;
float **posteriors_; // 16-bytes aligned posteriors
uint8_t **posteriors2_; // 16-bytes aligned posteriors
std::vector<int> leaf_counts_;
void createNodes(int num_nodes, RNG &rng);
void allocPosteriorsAligned(int num_leaves, int num_classes);
void freePosteriors(int which); // which: 1=posteriors_, 2=posteriors2_, 3=both
void init(int classes, int depth, RNG &rng);
void addExample(int class_id, uchar* patch_data);
void finalize(size_t reduced_num_dim, int num_quant_bits);
int getIndex(uchar* patch_data) const;
inline float* getPosteriorByIndex(int index);
inline const float* getPosteriorByIndex(int index) const;
inline uint8_t* getPosteriorByIndex2(int index);
inline const uint8_t* getPosteriorByIndex2(int index) const;
//void makeRandomMeasMatrix(float *cs_phi, PHI_DISTR_TYPE dt, size_t reduced_num_dim);
void convertPosteriorsToChar();
void makePosteriors2(int num_quant_bits);
void compressLeaves(size_t reduced_num_dim);
void estimateQuantPercForPosteriors(float perc[2]);
};
inline uchar* getData(IplImage* image)
{
return reinterpret_cast<uchar*>(image->imageData);
}
inline float* RandomizedTree::getPosteriorByIndex(int index)
{
return const_cast<float*>(const_cast<const RandomizedTree*>(this)->getPosteriorByIndex(index));
}
inline const float* RandomizedTree::getPosteriorByIndex(int index) const
{
return posteriors_[index];
}
inline uint8_t* RandomizedTree::getPosteriorByIndex2(int index)
{
return const_cast<uint8_t*>(const_cast<const RandomizedTree*>(this)->getPosteriorByIndex2(index));
}
inline const uint8_t* RandomizedTree::getPosteriorByIndex2(int index) const
{
return posteriors2_[index];
}
struct CV_EXPORTS RTreeNode
{
short offset1, offset2;
RTreeNode() {}
RTreeNode(uchar x1, uchar y1, uchar x2, uchar y2)
: offset1(y1*RandomizedTree::PATCH_SIZE + x1),
offset2(y2*RandomizedTree::PATCH_SIZE + x2)
{}
//! Left child on 0, right child on 1
inline bool operator() (uchar* patch_data) const
{
return patch_data[offset1] > patch_data[offset2];
}
};
class CV_EXPORTS RTreeClassifier
{
public:
static const int DEFAULT_TREES = 48;
static const size_t DEFAULT_NUM_QUANT_BITS = 4;
RTreeClassifier();
void train(std::vector<BaseKeypoint> const& base_set,
RNG &rng,
int num_trees = RTreeClassifier::DEFAULT_TREES,
int depth = RandomizedTree::DEFAULT_DEPTH,
int views = RandomizedTree::DEFAULT_VIEWS,
size_t reduced_num_dim = RandomizedTree::DEFAULT_REDUCED_NUM_DIM,
int num_quant_bits = DEFAULT_NUM_QUANT_BITS);
void train(std::vector<BaseKeypoint> const& base_set,
RNG &rng,
PatchGenerator &make_patch,
int num_trees = RTreeClassifier::DEFAULT_TREES,
int depth = RandomizedTree::DEFAULT_DEPTH,
int views = RandomizedTree::DEFAULT_VIEWS,
size_t reduced_num_dim = RandomizedTree::DEFAULT_REDUCED_NUM_DIM,
int num_quant_bits = DEFAULT_NUM_QUANT_BITS);
// sig must point to a memory block of at least classes()*sizeof(float|uint8_t) bytes
void getSignature(IplImage *patch, uint8_t *sig) const;
void getSignature(IplImage *patch, float *sig) const;
void getSparseSignature(IplImage *patch, float *sig, float thresh) const;
// TODO: deprecated in favor of getSignature overload, remove
void getFloatSignature(IplImage *patch, float *sig) const { getSignature(patch, sig); }
static int countNonZeroElements(float *vec, int n, double tol=1e-10);
static inline void safeSignatureAlloc(uint8_t **sig, int num_sig=1, int sig_len=176);
static inline uint8_t* safeSignatureAlloc(int num_sig=1, int sig_len=176);
inline int classes() const { return classes_; }
inline int original_num_classes() const { return original_num_classes_; }
void setQuantization(int num_quant_bits);
void discardFloatPosteriors();
void read(const char* file_name);
void read(std::istream &is);
void write(const char* file_name) const;
void write(std::ostream &os) const;
// experimental and debug
void saveAllFloatPosteriors(std::string file_url);
void saveAllBytePosteriors(std::string file_url);
void setFloatPosteriorsFromTextfile_176(std::string url);
float countZeroElements();
std::vector<RandomizedTree> trees_;
private:
int classes_;
int num_quant_bits_;
mutable uint8_t **posteriors_;
mutable uint16_t *ptemp_;
int original_num_classes_;
bool keep_floats_;
};
#if 0
class CV_EXPORTS CalonderClassifier
{
public:
@ -645,6 +846,7 @@ public:
#endif
void read( const FileNode& fn );
void read( std::istream& is );
void write( FileStorage& fs ) const;
bool empty() const;
@ -722,6 +924,7 @@ private:
vector<uchar> quantizedPosteriors;
#endif
};
#endif
/****************************************************************************************\
* One-Way Descriptor *
@ -1339,9 +1542,8 @@ protected:
SURF surf;
};
#if 0
template<typename T>
class CalonderDescriptorExtractor : public DescriptorExtractor
class CV_EXPORTS CalonderDescriptorExtractor : public DescriptorExtractor
{
public:
CalonderDescriptorExtractor( const string& classifierFile );
@ -1371,15 +1573,23 @@ void CalonderDescriptorExtractor<T>::compute( const cv::Mat& image,
/// @todo Check 16-byte aligned
descriptors.create(keypoints.size(), classifier_.classes(), cv::DataType<T>::type);
IplImage ipl = (IplImage)image;
int patchSize = RandomizedTree::PATCH_SIZE;
int offset = patchSize / 2;
for (size_t i = 0; i < keypoints.size(); ++i) {
cv::Point2f keypt = keypoints[i].pt;
cv::WImageView1_b patch = features::extractPatch(&ipl, keypt);
classifier_.getSignature(patch.Ipl(), descriptors.ptr<T>(i));
cv::Point2f pt = keypoints[i].pt;
IplImage ipl = image( Rect(pt.x - offset, pt.y - offset, patchSize, patchSize) );
classifier_.getSignature( &ipl, descriptors.ptr<T>(i));
}
}
#endif
template<typename T>
void CalonderDescriptorExtractor<T>::read( const FileNode &fn )
{}
template<typename T>
void CalonderDescriptorExtractor<T>::write( FileStorage &fs ) const
{}
CV_EXPORTS Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType );
@ -1478,7 +1688,7 @@ public:
/*
* Index the descriptors training set
*/
void index();
virtual void index() = 0;
/*
* Find the best match for each descriptor from a query set
@ -1574,18 +1784,15 @@ protected:
* Find matches; match() calls this. Must be implemented by the subclass.
* The mask may be empty.
*/
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const = 0;
virtual void matchImpl( const Mat& query, const Mat& mask, vector<int>& matches ) const = 0;
/*
* Find matches; match() calls this. Must be implemented by the subclass.
* The mask may be empty.
*/
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<DMatch>& matches ) const = 0;
virtual void matchImpl( const Mat& query, const Mat& mask, vector<DMatch>& matches ) const = 0;
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<vector<DMatch> >& matches, float threshold ) const = 0;
virtual void matchImpl( const Mat& query, const Mat& mask, vector<vector<DMatch> >& matches, float threshold ) const = 0;
static bool possibleMatch( const Mat& mask, int index_1, int index_2 )
@ -1614,36 +1821,36 @@ inline void DescriptorMatcher::add( const Mat& descriptors )
inline void DescriptorMatcher::match( const Mat& query, vector<int>& matches ) const
{
matchImpl( query, train, Mat(), matches );
matchImpl( query, Mat(), matches );
}
inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
vector<int>& matches ) const
{
matchImpl( query, train, mask, matches );
matchImpl( query, mask, matches );
}
inline void DescriptorMatcher::match( const Mat& query, vector<DMatch>& matches ) const
{
matchImpl( query, train, Mat(), matches );
matchImpl( query, Mat(), matches );
}
inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
vector<DMatch>& matches ) const
{
matchImpl( query, train, mask, matches );
matchImpl( query, mask, matches );
}
inline void DescriptorMatcher::match( const Mat& query, vector<vector<DMatch> >& matches, float threshold ) const
{
matchImpl( query, train, Mat(), matches, threshold );
matchImpl( query, Mat(), matches, threshold );
}
inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
vector<vector<DMatch> >& matches, float threshold ) const
{
matchImpl( query, train, mask, matches, threshold );
matchImpl( query, mask, matches, threshold );
}
@ -1666,26 +1873,22 @@ class CV_EXPORTS BruteForceMatcher : public DescriptorMatcher
{
public:
BruteForceMatcher( Distance d = Distance() ) : distance(d) {}
virtual void index() {}
protected:
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const;
virtual void matchImpl( const Mat& query, const Mat& mask, vector<int>& matches ) const;
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<DMatch>& matches ) const;
virtual void matchImpl( const Mat& query, const Mat& mask, vector<DMatch>& matches ) const;
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<vector<DMatch> >& matches, float threshold ) const;
virtual void matchImpl( const Mat& query, const Mat& mask, vector<vector<DMatch> >& matches, float threshold ) const;
Distance distance;
};
template<class Distance> inline
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const
void BruteForceMatcher<Distance>::matchImpl( const Mat& query, const Mat& mask, vector<int>& matches ) const
{
vector<DMatch> matchings;
matchImpl( descriptors_1, descriptors_2, mask, matchings);
matchImpl( query, mask, matchings);
matches.clear();
matches.resize( matchings.size() );
for( size_t i=0;i<matchings.size();i++)
@ -1695,33 +1898,32 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
}
template<class Distance> inline
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<DMatch>& matches ) const
void BruteForceMatcher<Distance>::matchImpl( const Mat& query, const Mat& mask, vector<DMatch>& matches ) const
{
typedef typename Distance::ValueType ValueType;
typedef typename Distance::ResultType DistanceType;
assert( mask.empty() || (mask.rows == descriptors_1.rows && mask.cols == descriptors_2.rows) );
assert( mask.empty() || (mask.rows == query.rows && mask.cols == train.rows) );
assert( descriptors_1.cols == descriptors_2.cols || descriptors_1.empty() || descriptors_2.empty() );
assert( DataType<ValueType>::type == descriptors_1.type() || descriptors_1.empty() );
assert( DataType<ValueType>::type == descriptors_2.type() || descriptors_2.empty() );
assert( query.cols == train.cols || query.empty() || train.empty() );
assert( DataType<ValueType>::type == query.type() || query.empty() );
assert( DataType<ValueType>::type == train.type() || train.empty() );
int dimension = descriptors_1.cols;
int dimension = query.cols;
matches.clear();
matches.resize(descriptors_1.rows);
matches.resize(query.rows);
for( int i = 0; i < descriptors_1.rows; i++ )
for( int i = 0; i < query.rows; i++ )
{
const ValueType* d1 = (const ValueType*)(descriptors_1.data + descriptors_1.step*i);
const ValueType* d1 = (const ValueType*)(query.data + query.step*i);
int matchIndex = -1;
DistanceType matchDistance = std::numeric_limits<DistanceType>::max();
for( int j = 0; j < descriptors_2.rows; j++ )
for( int j = 0; j < train.rows; j++ )
{
if( possibleMatch(mask, i, j) )
{
const ValueType* d2 = (const ValueType*)(descriptors_2.data + descriptors_2.step*j);
const ValueType* d2 = (const ValueType*)(train.data + train.step*j);
DistanceType curDistance = distance(d1, d2, dimension);
if( curDistance < matchDistance )
{
@ -1743,31 +1945,30 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
}
template<class Distance> inline
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<vector<DMatch> >& matches, float threshold ) const
void BruteForceMatcher<Distance>::matchImpl( const Mat& query, const Mat& mask, vector<vector<DMatch> >& matches, float threshold ) const
{
typedef typename Distance::ValueType ValueType;
typedef typename Distance::ResultType DistanceType;
assert( mask.empty() || (mask.rows == descriptors_1.rows && mask.cols == descriptors_2.rows) );
assert( mask.empty() || (mask.rows == query.rows && mask.cols == train.rows) );
assert( descriptors_1.cols == descriptors_2.cols || descriptors_1.empty() || descriptors_2.empty() );
assert( DataType<ValueType>::type == descriptors_1.type() || descriptors_1.empty() );
assert( DataType<ValueType>::type == descriptors_2.type() || descriptors_2.empty() );
assert( query.cols == train.cols || query.empty() || train.empty() );
assert( DataType<ValueType>::type == query.type() || query.empty() );
assert( DataType<ValueType>::type == train.type() || train.empty() );
int dimension = descriptors_1.cols;
int dimension = query.cols;
matches.clear();
matches.resize( descriptors_1.rows );
matches.resize( query.rows );
for( int i = 0; i < descriptors_1.rows; i++ )
for( int i = 0; i < query.rows; i++ )
{
const ValueType* d1 = (const ValueType*)(descriptors_1.data + descriptors_1.step*i);
const ValueType* d1 = (const ValueType*)(query.data + query.step*i);
for( int j = 0; j < descriptors_2.rows; j++ )
for( int j = 0; j < train.rows; j++ )
{
if( possibleMatch(mask, i, j) )
{
const ValueType* d2 = (const ValueType*)(descriptors_2.data + descriptors_2.step*j);
const ValueType* d2 = (const ValueType*)(train.data + train.step*j);
DistanceType curDistance = distance(d1, d2, dimension);
if( curDistance < threshold )
{
@ -1783,8 +1984,7 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
}
template<>
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& mask, vector<int>& matches ) const;
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& query, const Mat& mask, vector<int>& matches ) const;
CV_EXPORTS Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherType );
@ -1952,76 +2152,6 @@ protected:
Params params;
};
/*
* CalonderDescriptorMatch
*/
#if 0
class CV_EXPORTS CalonderDescriptorMatch : public GenericDescriptorMatch
{
public:
class Params
{
public:
static const int DEFAULT_NUM_TREES = 80;
static const int DEFAULT_DEPTH = 9;
static const int DEFAULT_VIEWS = 5000;
static const size_t DEFAULT_REDUCED_NUM_DIM = 176;
static const size_t DEFAULT_NUM_QUANT_BITS = 4;
static const int DEFAULT_PATCH_SIZE = PATCH_SIZE;
Params( const RNG& _rng = RNG(), const PatchGenerator& _patchGen = PatchGenerator(),
int _numTrees=DEFAULT_NUM_TREES,
int _depth=DEFAULT_DEPTH,
int _views=DEFAULT_VIEWS,
size_t _reducedNumDim=DEFAULT_REDUCED_NUM_DIM,
int _numQuantBits=DEFAULT_NUM_QUANT_BITS,
bool _printStatus=true,
int _patchSize=DEFAULT_PATCH_SIZE );
Params( const string& _filename );
RNG rng;
PatchGenerator patchGen;
int numTrees;
int depth;
int views;
int patchSize;
size_t reducedNumDim;
int numQuantBits;
bool printStatus;
string filename;
};
CalonderDescriptorMatch();
CalonderDescriptorMatch( const Params& _params );
virtual ~CalonderDescriptorMatch();
void initialize( const Params& _params );
virtual void add( const Mat& image, vector<KeyPoint>& keypoints );
virtual void match( const Mat& image, vector<KeyPoint>& keypoints, vector<int>& indices );
virtual void classify( const Mat& image, vector<KeyPoint>& keypoints );
virtual void clear ();
virtual void read( const FileNode &fn );
virtual void write( FileStorage& fs ) const;
protected:
void trainRTreeClassifier();
Mat extractPatch( const Mat& image, const Point& pt, int patchSize ) const;
void calcBestProbAndMatchIdx( const Mat& image, const Point& pt,
float& bestProb, int& bestMatchIdx, float* signature );
Ptr<RTreeClassifier> classifier;
Params params;
};
#endif
/*
* FernDescriptorMatch
*/
@ -2128,6 +2258,9 @@ protected:
//vector<int> classIds;
};
CV_EXPORTS Mat windowedMatchingMask( const vector<KeyPoint>& keypoints1, const vector<KeyPoint>& keypoints2,
float maxDeltaX, float maxDeltaY );
struct CV_EXPORTS DrawMatchesFlags
{

File diff suppressed because it is too large Load Diff

View File

@ -51,6 +51,24 @@ using namespace std;
namespace cv
{
Mat windowedMatchingMask( const vector<KeyPoint>& keypoints1, const vector<KeyPoint>& keypoints2,
float maxDeltaX, float maxDeltaY )
{
if( keypoints1.empty() || keypoints2.empty() )
return Mat();
Mat mask( keypoints1.size(), keypoints2.size(), CV_8UC1 );
for( size_t i = 0; i < keypoints1.size(); i++ )
{
for( size_t j = 0; j < keypoints2.size(); j++ )
{
Point2f diff = keypoints2[j].pt - keypoints1[i].pt;
mask.at<uchar>(i, j) = std::abs(diff.x) < maxDeltaX && std::abs(diff.y) < maxDeltaY;
}
}
return mask;
}
void drawMatches( const Mat& img1, const vector<KeyPoint>& keypoints1,
const Mat& img2,const vector<KeyPoint>& keypoints2,
const vector<int>& matches, Mat& outImg,
@ -278,20 +296,19 @@ Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherT
* BruteForceMatcher L2 specialization *
\****************************************************************************************/
template<>
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
const Mat& /*mask*/, vector<int>& matches ) const
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& query, const Mat& /*mask*/, vector<int>& matches ) const
{
matches.clear();
matches.reserve( descriptors_1.rows );
matches.reserve( query.rows );
//TODO: remove _DEBUG if bag 416 fixed
#if (defined _DEBUG || !defined HAVE_EIGEN2)
Mat norms;
cv::reduce( descriptors_2.mul( descriptors_2 ), norms, 1, 0);
cv::reduce( train.mul( train ), norms, 1, 0);
norms = norms.t();
Mat desc_2t = descriptors_2.t();
for( int i=0;i<descriptors_1.rows;i++ )
Mat desc_2t = train.t();
for( int i=0;i<query.rows;i++ )
{
Mat distances = (-2)*descriptors_1.row(i)*desc_2t;
Mat distances = (-2)*query.row(i)*desc_2t;
distances += norms;
Point minLoc;
minMaxLoc ( distances, 0, 0, &minLoc );
@ -631,160 +648,6 @@ void OneWayDescriptorMatch::clear ()
base->clear ();
}
/****************************************************************************************\
* CalonderDescriptorMatch *
\****************************************************************************************/
#if 0
CalonderDescriptorMatch::Params::Params( const RNG& _rng, const PatchGenerator& _patchGen,
int _numTrees, int _depth, int _views,
size_t _reducedNumDim,
int _numQuantBits,
bool _printStatus,
int _patchSize ) :
rng(_rng), patchGen(_patchGen), numTrees(_numTrees), depth(_depth), views(_views),
patchSize(_patchSize), reducedNumDim(_reducedNumDim), numQuantBits(_numQuantBits), printStatus(_printStatus)
{}
CalonderDescriptorMatch::Params::Params( const string& _filename )
{
filename = _filename;
}
CalonderDescriptorMatch::CalonderDescriptorMatch()
{}
CalonderDescriptorMatch::CalonderDescriptorMatch( const Params& _params )
{
initialize(_params);
}
CalonderDescriptorMatch::~CalonderDescriptorMatch()
{}
void CalonderDescriptorMatch::initialize( const Params& _params )
{
classifier.release();
params = _params;
if( !params.filename.empty() )
{
classifier = new RTreeClassifier;
classifier->read( params.filename.c_str() );
}
}
void CalonderDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
{
if( params.filename.empty() )
collection.add( image, keypoints );
}
Mat CalonderDescriptorMatch::extractPatch( const Mat& image, const Point& pt, int patchSize ) const
{
const int offset = patchSize / 2;
return image( Rect(pt.x - offset, pt.y - offset, patchSize, patchSize) );
}
void CalonderDescriptorMatch::calcBestProbAndMatchIdx( const Mat& image, const Point& pt,
float& bestProb, int& bestMatchIdx, float* signature )
{
IplImage roi = extractPatch( image, pt, params.patchSize );
classifier->getSignature( &roi, signature );
bestProb = 0;
bestMatchIdx = -1;
for( int ci = 0; ci < classifier->classes(); ci++ )
{
if( signature[ci] > bestProb )
{
bestProb = signature[ci];
bestMatchIdx = ci;
}
}
}
void CalonderDescriptorMatch::trainRTreeClassifier()
{
if( classifier.empty() )
{
assert( params.filename.empty() );
classifier = new RTreeClassifier;
vector<BaseKeypoint> baseKeyPoints;
vector<IplImage> iplImages( collection.images.size() );
for( size_t imageIdx = 0; imageIdx < collection.images.size(); imageIdx++ )
{
iplImages[imageIdx] = collection.images[imageIdx];
for( size_t pointIdx = 0; pointIdx < collection.points[imageIdx].size(); pointIdx++ )
{
BaseKeypoint bkp;
KeyPoint kp = collection.points[imageIdx][pointIdx];
bkp.x = cvRound(kp.pt.x);
bkp.y = cvRound(kp.pt.y);
bkp.image = &iplImages[imageIdx];
baseKeyPoints.push_back(bkp);
}
}
classifier->train( baseKeyPoints, params.rng, params.patchGen, params.numTrees,
params.depth, params.views, params.reducedNumDim, params.numQuantBits,
params.printStatus );
}
}
void CalonderDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<int>& indices )
{
trainRTreeClassifier();
float bestProb = 0;
AutoBuffer<float> signature( classifier->classes() );
indices.resize( keypoints.size() );
for( size_t pi = 0; pi < keypoints.size(); pi++ )
calcBestProbAndMatchIdx( image, keypoints[pi].pt, bestProb, indices[pi], signature );
}
void CalonderDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& keypoints )
{
trainRTreeClassifier();
AutoBuffer<float> signature( classifier->classes() );
for( size_t pi = 0; pi < keypoints.size(); pi++ )
{
float bestProb = 0;
int bestMatchIdx = -1;
calcBestProbAndMatchIdx( image, keypoints[pi].pt, bestProb, bestMatchIdx, signature );
keypoints[pi].class_id = collection.getKeyPoint(bestMatchIdx).class_id;
}
}
void CalonderDescriptorMatch::clear ()
{
GenericDescriptorMatch::clear();
classifier.release();
}
void CalonderDescriptorMatch::read( const FileNode &fn )
{
params.numTrees = fn["numTrees"];
params.depth = fn["depth"];
params.views = fn["views"];
params.patchSize = fn["patchSize"];
params.reducedNumDim = (int) fn["reducedNumDim"];
params.numQuantBits = fn["numQuantBits"];
params.printStatus = (int) fn["printStatus"] != 0;
}
void CalonderDescriptorMatch::write( FileStorage& fs ) const
{
fs << "numTrees" << params.numTrees;
fs << "depth" << params.depth;
fs << "views" << params.views;
fs << "patchSize" << params.patchSize;
fs << "reducedNumDim" << (int) params.reducedNumDim;
fs << "numQuantBits" << params.numQuantBits;
fs << "printStatus" << params.printStatus;
}
#endif
/****************************************************************************************\
* FernDescriptorMatch *
\****************************************************************************************/