opencv/modules/imgproc/src/spilltree.cpp

499 lines
14 KiB
C++
Raw Normal View History

/* Original code has been submitted by Liu Liu.
----------------------------------------------------------------------------------
* Spill-Tree for Approximate KNN Search
* Author: Liu Liu
* mailto: liuliu.1987+opencv@gmail.com
* Refer to Paper:
* An Investigation of Practical Approximate Nearest Neighbor Algorithms
* cvMergeSpillTree TBD
*
* Redistribution and use in source and binary forms, with or
* without modification, are permitted provided that the following
* conditions are met:
* Redistributions of source code must retain the above
* copyright notice, this list of conditions and the following
* disclaimer.
* Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials
* provided with the distribution.
* The name of Contributor may not be used to endorse or
* promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
* CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
* OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
* TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
* OF SUCH DAMAGE.
*/
#include "precomp.hpp"
#include "_featuretree.h"
// One node type is used for two roles:
//  * tree nodes (root, split nodes, leaves), shaped by icvDFSInitSpillTreeNode;
//  * point entries, one per data row, plus shallow clones made when a spill
//    split puts a point into both children.
// A tree node that has not been split (a leaf, or a node before splitting)
// keeps its points in a linked list: lc = head, rc = tail, cc = length.
// Inside that list each point entry uses lc = previous entry, rc = next entry.
struct CvSpillTreeNode
{
bool leaf; // true for leaf nodes (points searched linearly); point entries are also created with leaf = true
bool spill; // true if this node was split with an overlap buffer (spill split, defeatist search); false for plain splits and leaves
CvSpillTreeNode* lc; // split node: left child (p < mp, or p <= ub for a spill split); unsplit node: list head; point entry: previous entry
CvSpillTreeNode* rc; // split node: right child (p >= mp, or p >= lb for a spill split); unsplit node: list tail; point entry: next entry
int cc; // number of point entries held by this node (list length; split nodes keep their pre-split count)
CvMat* u; // split node: unit projection vector; NULL otherwise
CvMat* center; // split node: mean of its points (owned); point entry: 1 x d header viewing one data row
int i; // point entry: index of the data row it refers to
double r; // split node: largest L2 distance from center to any of its points
double ub; // split node: upper bound of the overlap buffer (mp + half-width)
double lb; // split node: lower bound of the overlap buffer (mp - half-width)
double mp; // split node: mean of the points' projections on u
double p; // point entry: projection on u of the node currently being split (scratch value)
};
// Spill tree over the rows of a data matrix (rows = points, cols = dimensions).
struct CvSpillTree
{
CvSpillTreeNode* root;
CvMat** refmat; // `total` 1 x d matrix headers, one per data row, viewing the caller's data; released with the tree
int total; // number of data rows (points), not the number of leaves
int naive; // nodes holding at most this many points become leaves and are searched linearly
int type; // matrix type of the data (copied from raw_data->type)
double rho; // a node is split without overlap if either spill child would hold >= round(cc*rho) points
double tau; // overlap buffer width, as a fraction of the projected distance between the two pivot points
};
// One slot of the k-NN result heap; index == -1 marks an empty slot.
struct CvResult
{
int index; // data row index of the candidate, or -1 for an empty slot
double distance; // L2 distance from the query to the candidate (cvNorm)
};
// Walks `total` entries of the linked list starting at `list` (following rc)
// and returns the entry whose center has the largest L2 distance from
// node->center. On ties the earliest entry wins; NULL when total <= 0.
static inline CvSpillTreeNode*
icvFarthestNode( CvSpillTreeNode* node,
                 CvSpillTreeNode* list,
                 int total )
{
    CvSpillTreeNode* best = NULL;
    double best_dist = -1.;
    CvSpillTreeNode* cur = list;
    for ( int remaining = total; remaining > 0; --remaining, cur = cur->rc )
    {
        const double dist = cvNorm( node->center, cur->center );
        if ( dist > best_dist )
        {
            best_dist = dist;
            best = cur;
        }
    }
    return best;
}
// Allocates a new node holding a shallow copy of `node`: pointer members
// (lc, rc, u, center) are shared with the source, not duplicated.
static inline CvSpillTreeNode*
icvCloneSpillTreeNode( CvSpillTreeNode* node )
{
    CvSpillTreeNode* copy = (CvSpillTreeNode*)cvAlloc( sizeof(*copy) );
    *copy = *node;
    return copy;
}
// Appends `append` to the tail of the list owned by `node`
// (lc = head, rc = tail, cc = length) and increments node->cc.
// The new entry's lc becomes the previous tail and its rc becomes NULL.
// The list is treated as empty when node->lc is NULL.
static inline void
icvAppendSpillTreeNode( CvSpillTreeNode* node,
                        CvSpillTreeNode* append )
{
    if ( node->lc != NULL )
    {
        CvSpillTreeNode* tail = node->rc;
        append->lc = tail;
        append->rc = NULL;
        tail->rc = append;
        node->rc = append;
    }
    else
    {
        // empty list: the new entry is both head and tail
        node->lc = append;
        node->rc = append;
        append->rc = NULL;
        append->lc = NULL;
    }
    node->cc += 1;
}
// Returns a void* to element number `step` of matrix `x` (an element offset,
// not a byte offset) for CV_32F / CV_64F depths, and NULL for any other depth.
// The offset is applied to the flat data pointer, so `row * cols` addresses a
// row correctly only when the matrix rows are stored contiguously.
#define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))
// Recursively splits `node`, whose points are held in its linked list
// (lc = head, rc = tail, cc = count).
//   tr   - tree parameters (naive, rho, tau, type)
//   d    - number of columns of each point
//   node - node to split. On return it is either a leaf (list kept) or a split
//          node with u, center, r, mp, lb, ub, spill set and lc/rc pointing to
//          two newly allocated children, which are split recursively.
// Point entries are moved into the children, not copied. The exception is a
// spill split: points inside the overlap buffer go to the left child as the
// original entry and to the right child as a shallow clone.
static void
icvDFSInitSpillTreeNode( const CvSpillTree* tr,
const int d,
CvSpillTreeNode* node )
{
if ( node->cc <= tr->naive )
{
// Few enough points: keep the list as a leaf (searched linearly) and stop recursing.
node->leaf = true;
node->spill = false;
return;
}
// Approximate the farthest pair of points: start from a random entry, take the
// entry farthest from it (lnode), then the entry farthest from lnode (rnode).
// NOTE(review): function-local static RNG shared by every call. Concurrent tree
// builds would race on it, and the tree shape depends on the order of builds.
static CvRNG rng_state = cvRNG(0xdeadbeef);
int rn = cvRandInt( &rng_state ) % node->cc;
CvSpillTreeNode* lnode = NULL;
CvSpillTreeNode* rnode = node->lc;
for ( int i = 0; i < rn; i++ )
rnode = rnode->rc;
lnode = icvFarthestNode( rnode, node->lc, node->cc );
rnode = icvFarthestNode( lnode, node->lc, node->cc );
// u: unit vector along (lnode - rnode); points are split by their projection on it.
node->u = cvCreateMat( 1, d, tr->type );
cvSub( lnode->center, rnode->center, node->u );
cvNormalize( node->u, node->u );
// center: arithmetic mean of all points in this node.
node->center = cvCreateMat( 1, d, tr->type );
cvZero( node->center );
CvSpillTreeNode* it = node->lc;
for ( int i = 0; i < node->cc; i++ )
{
cvAdd( it->center, node->center, node->center );
it = it->rc;
}
cvConvertScale( node->center, node->center, 1./node->cc );
// Project every point on u (stored in it->p); mp = mean projection,
// r = largest distance from center to any point (used for pruning during search).
it = node->lc;
node->r = -1.;
node->mp = 0;
for ( int i = 0; i < node->cc; i++ )
{
node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
double norm = cvNorm( node->center, it->center );
if ( norm > node->r )
node->r = norm;
it = it->rc;
}
node->mp = node->mp / node->cc;
// Overlap buffer [lb, ub] centred on mp. Its total width is tau times the
// projected distance between the two pivots (lnode->p - rnode->p).
double ob = (lnode->p-rnode->p)*tr->tau*.5;
node->ub = node->mp+ob;
node->lb = node->mp-ob;
// sl / sr: child sizes for a spill split (p <= ub / p >= lb).
// l / r:   child sizes for a plain split at mp (p < mp / p >= mp).
int sl = 0, l = 0;
int sr = 0, r = 0;
it = node->lc;
for ( int i = 0; i < node->cc; i++ )
{
if ( it->p <= node->ub )
sl++;
if ( it->p >= node->lb )
sr++;
if ( it->p < node->mp )
l++;
else
r++;
it = it->rc;
}
// All points fall on one side of mp (e.g. identical projections). A split
// would make no progress, so drop u/center and keep this node as a leaf.
if (( l == 0 )||( r == 0 ))
{
cvReleaseMat( &(node->u) );
cvReleaseMat( &(node->center) );
node->leaf = true;
node->spill = false;
return;
}
// Allocate the two (initially empty) children.
CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
memset(lc, 0, sizeof(CvSpillTreeNode));
CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
memset(rc, 0, sizeof(CvSpillTreeNode));
lc->lc = lc->rc = rc->lc = rc->rc = NULL;
lc->cc = rc->cc = 0;
int undo = cvRound(node->cc*tr->rho);
// If either spill child would hold at least round(cc*rho) points, the overlap is
// too large: do a plain split at mp instead (points moved, none duplicated).
if (( sl >= undo )||( sr >= undo ))
{
// it is not a spill point (defeatist search disabled)
it = node->lc;
for ( int i = 0; i < node->cc; i++ )
{
// Save the successor first: appending rewrites it->lc / it->rc.
CvSpillTreeNode* next = it->rc;
if ( it->p < node->mp )
icvAppendSpillTreeNode( lc, it );
else
icvAppendSpillTreeNode( rc, it );
it = next;
}
node->spill = false;
} else {
// Spill split: points below lb go left and points above ub go right.
// Points inside [lb, ub] go to both children (original left, shallow clone right).
it = node->lc;
for ( int i = 0; i < node->cc; i++ )
{
CvSpillTreeNode* next = it->rc;
if ( it->p < node->lb )
icvAppendSpillTreeNode( lc, it );
else if ( it->p > node->ub )
icvAppendSpillTreeNode( rc, it );
else {
CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
icvAppendSpillTreeNode( lc, it );
icvAppendSpillTreeNode( rc, cit );
}
it = next;
}
node->spill = true;
}
// The list head/tail fields now become the child pointers.
node->lc = lc;
node->rc = rc;
// recursion process
icvDFSInitSpillTreeNode( tr, d, node->lc );
icvDFSInitSpillTreeNode( tr, d, node->rc );
}
// Builds a spill tree over the rows of raw_data (n rows = points, d cols =
// dimensions).
//   naive - nodes holding at most this many points are kept as leaves
//   rho   - spill threshold, as a fraction of a node's point count
//   tau   - overlap-buffer ratio
// No data is copied. Each point entry gets a 1 x d matrix header that views its
// row of raw_data, so raw_data must stay alive and unchanged while the tree is used.
// Rows are located through raw_data->step, so non-continuous matrices (ROI
// views) are handled; for continuous data this is the same address as
// row * cols elements. An empty matrix yields a tree whose root is an empty leaf.
static CvSpillTree*
icvCreateSpillTree( const CvMat* raw_data,
                    const int naive,
                    const double rho,
                    const double tau )
{
    int n = raw_data->rows;
    int d = raw_data->cols;
    CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
    tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
    memset( tr->root, 0, sizeof(CvSpillTreeNode) );
    // Never request a zero-length array; at least one slot keeps indexing well-defined.
    tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*(n > 0 ? n : 1) );
    tr->total = n;
    tr->naive = naive;
    tr->rho = rho;
    tr->tau = tau;
    tr->type = raw_data->type;
    if ( n <= 0 )
    {
        // No points: an empty leaf root (cc = 0, no list) searches and releases cleanly.
        tr->total = 0;
        tr->root->leaf = true;
        return tr;
    }
    // Tie a linked list of point entries to the root: lc = head, rc = tail.
    CvSpillTreeNode* tail = NULL;
    for ( int i = 0; i < n; i++ )
    {
        CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
        memset( newnode, 0, sizeof(CvSpillTreeNode) );
        newnode->center = cvCreateMatHeader( 1, d, tr->type );
        // Address row i through the row step rather than as i*d elements.
        cvSetData( newnode->center,
                   raw_data->data.ptr + (size_t)i*raw_data->step,
                   raw_data->step );
        tr->refmat[i] = newnode->center;
        newnode->i = i;
        newnode->leaf = true;
        newnode->lc = tail;
        newnode->rc = NULL;
        if ( tail != NULL )
            tail->rc = newnode;
        else
            tr->root->lc = newnode;
        tail = newnode;
    }
    tr->root->rc = tail;
    tr->root->cc = n;
    icvDFSInitSpillTreeNode( tr, d, tr->root );
    return tr;
}
// Sifts heap[i] down to restore the max-heap property on heap[0..k).
// Ordering: an empty slot (index == -1) ranks above every real entry, and real
// entries rank by distance. As a result heap[0] holds an empty slot or the
// farthest candidate. Does nothing if heap[i] itself is empty.
static void
icvSpillTreeNodeHeapify( CvResult * heap,
                         int i,
                         const int k )
{
    if ( heap[i].index == -1 )
        return;
    for ( ;; )
    {
        int right = (i + 1) << 1;
        int left = right - 1;
        int top = i;
        if ( left < k && heap[left].index == -1 )
            top = left;
        else if ( right < k && heap[right].index == -1 )
            top = right;
        else
        {
            if ( left < k && heap[left].distance > heap[i].distance )
                top = left;
            if ( right < k && heap[right].distance > heap[top].distance )
                top = right;
        }
        if ( top == i )
            break;
        CvResult tmp = heap[top];
        heap[top] = heap[i];
        heap[i] = tmp;
        i = top;
    }
}
// Depth-first k-NN search below `node` for the query `desc`.
//   tr    - not used by this function
//   heap  - k result slots kept as a max-heap by icvSpillTreeNodeHeapify;
//           empty slots (index == -1) rank above every real candidate
//   es    - incremented each time a candidate is inserted into the heap
//   desc  - query row (1 x d)
//   k     - number of neighbours sought
//   emax  - if > 0, any call made once *es >= emax returns immediately
//   cache - flags indexed by data row (it->i). A row's flag is set once its
//           distance has been computed, so a point duplicated by a spill split
//           is scored only once per query.
static void
icvSpillTreeDFSearch( CvSpillTree* tr,
CvSpillTreeNode* node,
CvResult* heap,
int* es,
const CvMat* desc,
const int k,
const int emax,
bool * cache)
{
if ((emax > 0)&&( *es >= emax ))
return;
double dist, p=0;
double distance;
// Defeatist descent through spill nodes. When the query projects outside the
// overlap buffer, commit to that side without backtracking, provided that the
// child holds at least k points.
while ( node->spill )
{
// defeatist search
if ( !node->leaf )
p = cvDotProduct( node->u, desc );
if ( p < node->lb && node->lc->cc >= k ) // require >= k points in the child, otherwise better neighbours could be skipped
node = node->lc;
else if ( p > node->ub && node->rc->cc >= k )
node = node->rc;
else
break;
// Defensive check; a split node's children are assigned during construction.
if ( NULL == node )
return;
}
if ( node->leaf )
{
// Leaf: score every listed point not yet evaluated for this query. A point
// replaces the heap top if the top is an empty slot or the point is closer.
CvSpillTreeNode* it = node->lc;
for ( int i = 0; i < node->cc; i++ )
{
if ( !cache[it->i] )
{
distance = cvNorm( it->center, desc );
cache[it->i] = true;
if (( heap[0].index == -1)||( distance < heap[0].distance ))
{
CvResult current_result;
current_result.index = it->i;
current_result.distance = distance;
heap[0] = current_result;
icvSpillTreeNodeHeapify( heap, 0, k );
(*es)++;
}
}
it = it->rc;
}
return;
}
// Pruning: every point below this node lies within r of center, so none can be
// closer than dist - r (triangle inequality). If heap[0] is a real candidate
// whose distance is already below that bound, skip this subtree.
dist = cvNorm( node->center, desc );
// impossible case, skip
if (( heap[0].index != -1 )&&( dist-node->r > heap[0].distance ))
return;
p = cvDotProduct( node->u, desc );
// Guided DFS: visit the child on the query's side of mp first, then the other one.
if ( p < node->mp )
{
icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
} else {
icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax, cache );
icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax, cache );
}
}
// Answers a k-NN query for every row of desc.
//   desc    - queries, one per row, with the same element type as the tree's data
//   results - output; row j receives the neighbour row indices (int) for query j,
//             nearest first, with -1 where fewer than k neighbours were found
//   dist    - output; row j receives the matching distances (double). Entries for
//             missing neighbours are left unchanged.
//   k       - neighbours per query
//   emax    - search budget passed to icvSpillTreeDFSearch (<= 0: unlimited)
// desc, results and dist are addressed through their row step, so
// non-continuous views (ROIs) are read and written at the correct rows.
static void
icvFindSpillTreeFeatures( CvSpillTree* tr,
                          const CvMat* desc,
                          CvMat* results,
                          CvMat* dist,
                          const int k,
                          const int emax )
{
    // Compare element types only: the raw type field also carries the
    // continuity flag, which legitimately differs for ROI views.
    assert( CV_MAT_TYPE(desc->type) == CV_MAT_TYPE(tr->type) );
    CvResult* heap = (CvResult*)cvAlloc( k*sizeof(heap[0]) );
    bool* cache = (bool*)cvAlloc( sizeof(bool)*tr->total );
    for ( int j = 0; j < desc->rows; j++ )
    {
        // 1 x d header over query row j, located through the row step.
        CvMat query = cvMat( 1, desc->cols, desc->type,
                             desc->data.ptr + (size_t)j*desc->step );
        // Start with k empty slots; they rank above any real candidate.
        for ( int i = 0; i < k; i++ )
        {
            heap[i].index = -1;
            heap[i].distance = -1;
        }
        // Fresh "already evaluated" flags for this query.
        memset( cache, 0, sizeof(bool)*tr->total );
        int es = 0;
        icvSpillTreeDFSearch( tr, tr->root, heap, &es, &query, k, emax, cache );
        // In-place heap sort: repeatedly move the current maximum to the end. This
        // leaves candidates in ascending distance, with empty slots last.
        CvResult inp;
        for ( int i = k-1; i > 0; i-- )
        {
            CV_SWAP( heap[i], heap[0], inp );
            icvSpillTreeNodeHeapify( heap, 0, i );
        }
        int* rs = (int*)(results->data.ptr + (size_t)j*results->step);
        double* dt = (double*)(dist->data.ptr + (size_t)j*dist->step);
        for ( int i = 0; i < k; i++ )
        {
            if ( heap[i].index != -1 )
            {
                rs[i] = heap[i].index;
                dt[i] = heap[i].distance;
            }
            else
                rs[i] = -1;
        }
    }
    cvFree( &heap );
    cvFree( &cache );
}
// Frees `node` and everything below it.
// Split node: releases its u and center matrices, recurses into both children,
// then frees the node.
// Leaf: frees each of the cc point entries in its list (reached via rc), then
// frees the node. The data-row headers the entries point to are not released here.
static void
icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
{
    if ( !node->leaf )
    {
        cvReleaseMat( &node->u );
        cvReleaseMat( &node->center );
        icvDFSReleaseSpillTreeNode( node->lc );
        icvDFSReleaseSpillTreeNode( node->rc );
        cvFree( &node );
        return;
    }
    CvSpillTreeNode* entry = node->lc;
    int left = node->cc;
    while ( left-- > 0 )
    {
        CvSpillTreeNode* victim = entry;
        entry = entry->rc;
        cvFree( &victim );
    }
    cvFree( &node );
}
// Destroys the tree *tr. It releases the per-row matrix headers in refmat, the
// refmat array, every tree node, and finally the CvSpillTree itself (via cvFree(tr)).
static void
icvReleaseSpillTree( CvSpillTree** tr )
{
    CvSpillTree* tree = *tr;
    CvMat** headers = tree->refmat;
    for ( int row = 0; row < tree->total; ++row )
        cvReleaseMat( headers + row );
    cvFree( &tree->refmat );
    icvDFSReleaseSpillTreeNode( tree->root );
    cvFree( tr );
}
class CvSpillTreeWrap : public CvFeatureTree {
CvSpillTree* tr;
public:
CvSpillTreeWrap(const CvMat* raw_data,
const int naive,
const double rho,
const double tau) {
tr = icvCreateSpillTree(raw_data, naive, rho, tau);
}
~CvSpillTreeWrap() {
icvReleaseSpillTree(&tr);
}
void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
}
};
// Public factory: builds a spill tree over the rows of raw_data and returns it
// through the generic CvFeatureTree interface. The returned object is allocated
// with new and owned by the caller.
CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
                                  const int naive,
                                  const double rho,
                                  const double tau ) {
    CvSpillTreeWrap* wrapper = new CvSpillTreeWrap(raw_data, naive, rho, tau);
    return wrapper;
}