477 lines
12 KiB
C++
477 lines
12 KiB
C++
/***********************************************************************
|
|
* Software License Agreement (BSD License)
|
|
*
|
|
* Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
|
|
* Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
|
|
*
|
|
* THE BSD LICENSE
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions
|
|
* are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright
|
|
* notice, this list of conditions and the following disclaimer.
|
|
* 2. Redistributions in binary form must reproduce the above copyright
|
|
* notice, this list of conditions and the following disclaimer in the
|
|
* documentation and/or other materials provided with the distribution.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
|
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
|
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
|
|
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
|
|
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
|
|
* NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
|
|
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*************************************************************************/
|
|
|
|
#include <stdexcept>
|
|
#include <vector>
|
|
#include "flann.h"
|
|
#include "timer.h"
|
|
#include "common.h"
|
|
#include "logger.h"
|
|
#include "index_testing.h"
|
|
#include "saving.h"
|
|
#include "object_factory.h"
|
|
// index types
|
|
#include "kdtree_index.h"
|
|
#include "kmeans_index.h"
|
|
#include "composite_index.h"
|
|
#include "linear_index.h"
|
|
#include "autotuned_index.h"
|
|
|
|
#include <typeinfo>
|
|
using namespace std;
|
|
|
|
|
|
|
|
#include "flann.h"
|
|
|
|
#ifdef WIN32
|
|
#define EXPORTED extern "C" __declspec(dllexport)
|
|
#else
|
|
#define EXPORTED extern "C"
|
|
#endif
|
|
|
|
|
|
namespace cvflann
|
|
{
|
|
|
|
typedef ObjectFactory<IndexParams, flann_algorithm_t> ParamsFactory;
|
|
|
|
|
|
IndexParams* IndexParams::createFromParameters(const FLANNParameters& p)
|
|
{
|
|
IndexParams* params = ParamsFactory::instance().create(p.algorithm);
|
|
params->fromParameters(p);
|
|
|
|
return params;
|
|
}
|
|
|
|
NNIndex* LinearIndexParams::createIndex(const Matrix<float>& dataset) const
|
|
{
|
|
return new LinearIndex(dataset, *this);
|
|
}
|
|
|
|
|
|
NNIndex* KDTreeIndexParams::createIndex(const Matrix<float>& dataset) const
|
|
{
|
|
return new KDTreeIndex(dataset, *this);
|
|
}
|
|
|
|
NNIndex* KMeansIndexParams::createIndex(const Matrix<float>& dataset) const
|
|
{
|
|
return new KMeansIndex(dataset, *this);
|
|
}
|
|
|
|
|
|
NNIndex* CompositeIndexParams::createIndex(const Matrix<float>& dataset) const
|
|
{
|
|
return new CompositeIndex(dataset, *this);
|
|
}
|
|
|
|
|
|
NNIndex* AutotunedIndexParams::createIndex(const Matrix<float>& dataset) const
|
|
{
|
|
return new AutotunedIndex(dataset, *this);
|
|
}
|
|
|
|
|
|
NNIndex* SavedIndexParams::createIndex(const Matrix<float>& dataset) const
|
|
{
|
|
|
|
FILE* fin = fopen(filename.c_str(), "rb");
|
|
if (fin==NULL) {
|
|
return NULL;
|
|
}
|
|
IndexHeader header = load_header(fin);
|
|
rewind(fin);
|
|
IndexParams* params = ParamsFactory::instance().create(header.index_type);
|
|
NNIndex* nnIndex = params->createIndex(dataset);
|
|
nnIndex->loadIndex(fin);
|
|
fclose(fin);
|
|
delete params; //?
|
|
|
|
return nnIndex;
|
|
|
|
}
|
|
|
|
class StaticInit
|
|
{
|
|
public:
|
|
StaticInit()
|
|
{
|
|
ParamsFactory::instance().register_<LinearIndexParams>(LINEAR);
|
|
ParamsFactory::instance().register_<KDTreeIndexParams>(KDTREE);
|
|
ParamsFactory::instance().register_<KMeansIndexParams>(KMEANS);
|
|
ParamsFactory::instance().register_<CompositeIndexParams>(COMPOSITE);
|
|
ParamsFactory::instance().register_<AutotunedIndexParams>(AUTOTUNED);
|
|
ParamsFactory::instance().register_<SavedIndexParams>(SAVED);
|
|
}
|
|
};
|
|
StaticInit __init;
|
|
|
|
|
|
|
|
Index::Index(const Matrix<float>& dataset, const IndexParams& params)
|
|
{
|
|
nnIndex = params.createIndex(dataset);
|
|
nnIndex->buildIndex();
|
|
}
|
|
|
|
Index::~Index()
|
|
{
|
|
delete nnIndex;
|
|
}
|
|
|
|
|
|
void Index::knnSearch(const Matrix<float>& queries, Matrix<int>& indices, Matrix<float>& dists, int knn, const SearchParams& searchParams)
|
|
{
|
|
assert(queries.cols==nnIndex->veclen());
|
|
assert(indices.rows>=queries.rows);
|
|
assert(dists.rows>=queries.rows);
|
|
assert(indices.cols>=knn);
|
|
assert(dists.cols>=knn);
|
|
|
|
|
|
search_for_neighbors(*nnIndex, queries, indices, dists, searchParams);
|
|
}
|
|
|
|
int Index::radiusSearch(const Matrix<float>& query, Matrix<int> indices, Matrix<float> dists, float radius, const SearchParams& searchParams)
|
|
{
|
|
if (query.rows!=1) {
|
|
printf("I can only search one feature at a time for range search\n");
|
|
return -1;
|
|
}
|
|
assert(query.cols==nnIndex->veclen());
|
|
|
|
RadiusResultSet resultSet(radius);
|
|
resultSet.init(query.data, query.cols);
|
|
nnIndex->findNeighbors(resultSet,query.data,searchParams);
|
|
|
|
// TODO: optimize here
|
|
int* neighbors = resultSet.getNeighbors();
|
|
float* distances = resultSet.getDistances();
|
|
int count_nn = min((long)resultSet.size(), indices.cols);
|
|
|
|
assert (dists.cols>=count_nn);
|
|
|
|
for (int i=0;i<count_nn;++i) {
|
|
indices[0][i] = neighbors[i];
|
|
dists[0][i] = distances[i];
|
|
}
|
|
|
|
return count_nn;
|
|
}
|
|
|
|
|
|
void Index::save(string filename)
|
|
{
|
|
FILE* fout = fopen(filename.c_str(), "wb");
|
|
if (fout==NULL) {
|
|
logger.error("Cannot open file: %s", filename.c_str());
|
|
throw FLANNException("Cannot open file");
|
|
}
|
|
nnIndex->saveIndex(fout);
|
|
fclose(fout);
|
|
}
|
|
|
|
int Index::size() const
|
|
{
|
|
return nnIndex->size();
|
|
}
|
|
|
|
int Index::veclen() const
|
|
{
|
|
return nnIndex->veclen();
|
|
}
|
|
|
|
|
|
int hierarchicalClustering(const Matrix<float>& features, Matrix<float>& centers, const KMeansIndexParams& params)
|
|
{
|
|
KMeansIndex kmeans(features, params);
|
|
kmeans.buildIndex();
|
|
|
|
int clusterNum = kmeans.getClusterCenters(centers);
|
|
return clusterNum;
|
|
}
|
|
|
|
} // namespace FLANN
|
|
|
|
|
|
|
|
using namespace cvflann;
|
|
|
|
typedef NNIndex* NNIndexPtr;
|
|
typedef Matrix<float>* MatrixPtr;
|
|
|
|
|
|
|
|
void init_flann_parameters(FLANNParameters* p)
|
|
{
|
|
if (p != NULL) {
|
|
flann_log_verbosity(p->log_level);
|
|
if (p->random_seed>0) {
|
|
seed_random(p->random_seed);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
EXPORTED void flann_log_verbosity(int level)
|
|
{
|
|
if (level>=0) {
|
|
logger.setLevel(level);
|
|
}
|
|
}
|
|
|
|
EXPORTED void flann_set_distance_type(flann_distance_t distance_type, int order)
|
|
{
|
|
flann_distance_type = distance_type;
|
|
flann_minkowski_order = order;
|
|
}
|
|
|
|
|
|
EXPORTED flann_index_t flann_build_index(float* dataset, int rows, int cols, float* /*speedup*/, FLANNParameters* flann_params)
|
|
{
|
|
try {
|
|
init_flann_parameters(flann_params);
|
|
if (flann_params == NULL) {
|
|
throw FLANNException("The flann_params argument must be non-null");
|
|
}
|
|
IndexParams* params = IndexParams::createFromParameters(*flann_params);
|
|
Index* index = new Index(Matrix<float>(rows,cols,dataset), *params);
|
|
|
|
return index;
|
|
}
|
|
catch (runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return NULL;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
EXPORTED int flann_save_index(flann_index_t index_ptr, char* filename)
|
|
{
|
|
try {
|
|
if (index_ptr==NULL) {
|
|
throw FLANNException("Invalid index");
|
|
}
|
|
|
|
Index* index = (Index*)index_ptr;
|
|
index->save(filename);
|
|
|
|
return 0;
|
|
}
|
|
catch(runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
|
|
EXPORTED FLANN_INDEX flann_load_index(char* filename, float* dataset, int rows, int cols)
|
|
{
|
|
try {
|
|
Index* index = new Index(Matrix<float>(rows,cols,dataset), SavedIndexParams(filename));
|
|
return index;
|
|
}
|
|
catch(runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return NULL;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
EXPORTED int flann_find_nearest_neighbors(float* dataset, int rows, int cols, float* testset, int tcount, int* result, float* dists, int nn, FLANNParameters* flann_params)
|
|
{
|
|
int _result = 0;
|
|
try {
|
|
init_flann_parameters(flann_params);
|
|
|
|
IndexParams* params = IndexParams::createFromParameters(*flann_params);
|
|
Index* index = new Index(Matrix<float>(rows,cols,dataset), *params);
|
|
Matrix<int> m_indices(tcount, nn, result);
|
|
Matrix<float> m_dists(tcount, nn, dists);
|
|
index->knnSearch(Matrix<float>(tcount, index->veclen(), testset),
|
|
m_indices,
|
|
m_dists, nn, SearchParams(flann_params->checks) );
|
|
}
|
|
catch(runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
_result = -1;
|
|
}
|
|
|
|
return _result;
|
|
}
|
|
|
|
|
|
EXPORTED int flann_find_nearest_neighbors_index(flann_index_t index_ptr, float* testset, int tcount, int* result, float* dists, int nn, int checks, FLANNParameters* flann_params)
|
|
{
|
|
try {
|
|
init_flann_parameters(flann_params);
|
|
if (index_ptr==NULL) {
|
|
throw FLANNException("Invalid index");
|
|
}
|
|
Index* index = (Index*) index_ptr;
|
|
|
|
Matrix<int> m_indices(tcount, nn, result);
|
|
Matrix<float> m_dists(tcount, nn, dists);
|
|
index->knnSearch(Matrix<float>(tcount, index->veclen(), testset),
|
|
m_indices,
|
|
m_dists, nn, SearchParams(checks) );
|
|
|
|
}
|
|
catch(runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
|
|
EXPORTED int flann_radius_search(FLANN_INDEX index_ptr,
|
|
float* query,
|
|
int* indices,
|
|
float* dists,
|
|
int max_nn,
|
|
float radius,
|
|
int checks,
|
|
FLANNParameters* flann_params)
|
|
{
|
|
try {
|
|
init_flann_parameters(flann_params);
|
|
if (index_ptr==NULL) {
|
|
throw FLANNException("Invalid index");
|
|
}
|
|
Index* index = (Index*) index_ptr;
|
|
|
|
Matrix<int> m_indices(1, max_nn, indices);
|
|
Matrix<float> m_dists(1, max_nn, dists);
|
|
int count = index->radiusSearch(Matrix<float>(1, index->veclen(), query),
|
|
m_indices,
|
|
m_dists, radius, SearchParams(checks) );
|
|
|
|
|
|
return count;
|
|
}
|
|
catch(runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
|
|
}
|
|
|
|
|
|
EXPORTED int flann_free_index(FLANN_INDEX index_ptr, FLANNParameters* flann_params)
|
|
{
|
|
try {
|
|
init_flann_parameters(flann_params);
|
|
if (index_ptr==NULL) {
|
|
throw FLANNException("Invalid index");
|
|
}
|
|
Index* index = (Index*) index_ptr;
|
|
delete index;
|
|
|
|
return 0;
|
|
}
|
|
catch(runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
|
|
EXPORTED int flann_compute_cluster_centers(float* dataset, int rows, int cols, int clusters, float* result, FLANNParameters* flann_params)
|
|
{
|
|
try {
|
|
init_flann_parameters(flann_params);
|
|
|
|
MatrixPtr inputData = new Matrix<float>(rows,cols,dataset);
|
|
KMeansIndexParams params(flann_params->branching, flann_params->iterations, flann_params->centers_init, flann_params->cb_index);
|
|
Matrix<float> centers(clusters, cols, result);
|
|
int clusterNum = hierarchicalClustering(*inputData,centers, params);
|
|
|
|
return clusterNum;
|
|
} catch (runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
|
|
EXPORTED void compute_ground_truth_float(float* dataset, int dshape[], float* testset, int tshape[], int* match, int mshape[], int skip)
|
|
{
|
|
assert(dshape[1]==tshape[1]);
|
|
assert(tshape[0]==mshape[0]);
|
|
|
|
Matrix<int> _match(mshape[0], mshape[1], match);
|
|
compute_ground_truth(Matrix<float>(dshape[0], dshape[1], dataset), Matrix<float>(tshape[0], tshape[1], testset), _match, skip);
|
|
}
|
|
|
|
|
|
EXPORTED float test_with_precision(FLANN_INDEX index_ptr, float* dataset, int dshape[], float* testset, int tshape[], int* matches, int mshape[],
|
|
int nn, float precision, int* checks, int skip = 0)
|
|
{
|
|
assert(dshape[1]==tshape[1]);
|
|
assert(tshape[0]==mshape[0]);
|
|
|
|
try {
|
|
if (index_ptr==NULL) {
|
|
throw FLANNException("Invalid index");
|
|
}
|
|
NNIndexPtr index = (NNIndexPtr)index_ptr;
|
|
return test_index_precision(*index, Matrix<float>(dshape[0], dshape[1],dataset), Matrix<float>(tshape[0], tshape[1], testset),
|
|
Matrix<int>(mshape[0],mshape[1],matches), precision, *checks, nn, skip);
|
|
} catch (runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
}
|
|
|
|
EXPORTED float test_with_checks(FLANN_INDEX index_ptr, float* dataset, int dshape[], float* testset, int tshape[], int* matches, int mshape[],
|
|
int nn, int checks, float* precision, int skip = 0)
|
|
{
|
|
assert(dshape[1]==tshape[1]);
|
|
assert(tshape[0]==mshape[0]);
|
|
|
|
try {
|
|
if (index_ptr==NULL) {
|
|
throw FLANNException("Invalid index");
|
|
}
|
|
NNIndexPtr index = (NNIndexPtr)index_ptr;
|
|
return test_index_checks(*index, Matrix<float>(dshape[0], dshape[1],dataset), Matrix<float>(tshape[0], tshape[1], testset),
|
|
Matrix<int>(mshape[0],mshape[1],matches), checks, *precision, nn, skip);
|
|
} catch (runtime_error& e) {
|
|
logger.error("Caught exception: %s\n",e.what());
|
|
return -1;
|
|
}
|
|
}
|