From c6a27b3d2f6b467ef97acf70655f4a910e1a4f77 Mon Sep 17 00:00:00 2001 From: Vadim Pisarevsky Date: Thu, 5 Apr 2012 13:01:34 +0000 Subject: [PATCH] probably, ultimately fixed the problem of empty clusters in kmeans; added test for singular kmeans cases --- modules/core/src/matrix.cpp | 14 +++++++-- modules/core/test/test_math.cpp | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/modules/core/src/matrix.cpp b/modules/core/src/matrix.cpp index ebacaa4f4..408370746 100644 --- a/modules/core/src/matrix.cpp +++ b/modules/core/src/matrix.cpp @@ -2489,7 +2489,7 @@ double cv::kmeans( InputArray _data, int K, } int* labels = _labels.ptr(); - Mat centers(K, dims, type), old_centers(K, dims, type); + Mat centers(K, dims, type), old_centers(K, dims, type), temp(1, dims, type); vector counters(K); vector _box(dims); Vec2f* box = &_box[0]; @@ -2533,7 +2533,7 @@ double cv::kmeans( InputArray _data, int K, for( a = 0; a < attempts; a++ ) { double max_center_shift = DBL_MAX; - for( iter = 0; iter < criteria.maxCount && max_center_shift > criteria.epsilon; iter++ ) + for( iter = 0;; ) { swap(centers, old_centers); @@ -2609,7 +2609,11 @@ double cv::kmeans( InputArray _data, int K, double max_dist = 0; int farthest_i = -1; float* new_center = centers.ptr(k); - float* old_center = centers.ptr(max_k); + float* _old_center = centers.ptr(max_k); + float* old_center = temp.ptr(); + float scale = 1.f/counters[max_k]; + for( j = 0; j < dims; j++ ) + old_center[j] = _old_center[j]*scale; for( i = 0; i < N; i++ ) { @@ -2627,6 +2631,7 @@ double cv::kmeans( InputArray _data, int K, counters[max_k]--; counters[k]++; + labels[farthest_i] = k; sample = data.ptr(farthest_i); for( j = 0; j < dims; j++ ) @@ -2658,6 +2663,9 @@ double cv::kmeans( InputArray _data, int K, } } } + + if( ++iter == MAX(criteria.maxCount, 2) || max_center_shift <= criteria.epsilon ) + break; // assign labels compactness = 0; diff --git a/modules/core/test/test_math.cpp b/modules/core/test/test_math.cpp index 5785e8312..77bb2dbd0 100644 --- a/modules/core/test/test_math.cpp +++ b/modules/core/test/test_math.cpp @@ -2428,5 +2428,57 @@ TEST(Core_SolvePoly, accuracy) { Core_SolvePolyTest test; test.safe_run(); } // TODO: eigenvv, invsqrt, cbrt, fastarctan, (round, floor, ceil(?)), + +class CV_KMeansSingularTest : public cvtest::BaseTest +{ +public: + CV_KMeansSingularTest() {} + ~CV_KMeansSingularTest() {} +protected: + void run(int) + { + try + { + RNG& rng = theRNG(); + const int MAX_DIM=5; + int MAX_POINTS = 100; + for( int iter = 0; iter < 100; iter++ ) + { + ts->update_context(this, iter, true); + int dims = rng.uniform(1, MAX_DIM+1); + int N = rng.uniform(1, MAX_POINTS+1); + int N0 = rng.uniform(1, N/10+1); + int K = rng.uniform(1, N+1); + + Mat data0(N0, dims, CV_32F), labels; + rng.fill(data0, RNG::UNIFORM, -1, 1); + + Mat data(N, dims, CV_32F); + for( int i = 0; i < N; i++ ) + data0.row(rng.uniform(0, N0)).copyTo(data.row(i)); + + kmeans(data, K, labels, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 30, 0), + 5, KMEANS_PP_CENTERS); + + Mat hist(K, 1, CV_32S, Scalar(0)); + for( int i = 0; i < N; i++ ) + { + int l = labels.at(i); + CV_Assert( 0 <= l && l < K ); + hist.at(l)++; + } + for( int i = 0; i < K; i++ ) + CV_Assert( hist.at(i) != 0 ); + } + } + catch(...) + { + ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH); + } + } +}; + +TEST(Core_KMeans, singular) { CV_KMeansSingularTest test; test.safe_run(); } + /* End of file. */