From 61c7c441b980feb28cfa566512e7af904c8bac9b Mon Sep 17 00:00:00 2001 From: Vadim Pisarevsky Date: Fri, 6 Apr 2012 13:22:08 +0000 Subject: [PATCH] handle single-point sets in kmeans properly --- modules/core/src/matrix.cpp | 5 +++-- modules/core/test/test_math.cpp | 28 +++++++++++++++++----------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/modules/core/src/matrix.cpp b/modules/core/src/matrix.cpp index 408370746..dddca6bf6 100644 --- a/modules/core/src/matrix.cpp +++ b/modules/core/src/matrix.cpp @@ -2459,8 +2459,9 @@ double cv::kmeans( InputArray _data, int K, { const int SPP_TRIALS = 3; Mat data = _data.getMat(); - int N = data.rows > 1 ? data.rows : data.cols; - int dims = (data.rows > 1 ? data.cols : 1)*data.channels(); + bool isrow = data.rows == 1 && data.channels() > 1; + int N = !isrow ? data.rows : data.cols; + int dims = (!isrow ? data.cols : 1)*data.channels(); int type = data.depth(); attempts = std::max(attempts, 1); diff --git a/modules/core/test/test_math.cpp b/modules/core/test/test_math.cpp index 77bb2dbd0..35afafc37 100644 --- a/modules/core/test/test_math.cpp +++ b/modules/core/test/test_math.cpp @@ -2437,42 +2437,48 @@ public: protected: void run(int) { + int i, iter = 0, N = 0, N0 = 0, K = 0, dims = 0; + Mat labels; try { RNG& rng = theRNG(); const int MAX_DIM=5; - int MAX_POINTS = 100; - for( int iter = 0; iter < 100; iter++ ) + int MAX_POINTS = 100, maxIter = 100; + for( iter = 0; iter < maxIter; iter++ ) { ts->update_context(this, iter, true); - int dims = rng.uniform(1, MAX_DIM+1); - int N = rng.uniform(1, MAX_POINTS+1); - int N0 = rng.uniform(1, N/10+1); - int K = rng.uniform(1, N+1); + dims = rng.uniform(1, MAX_DIM+1); + N = rng.uniform(1, MAX_POINTS+1); + N0 = rng.uniform(1, MAX(N/10, 2)); + K = rng.uniform(1, N+1); - Mat data0(N0, dims, CV_32F), labels; + Mat data0(N0, dims, CV_32F); rng.fill(data0, RNG::UNIFORM, -1, 1); Mat data(N, dims, CV_32F); - for( int i = 0; i < N; i++ ) + for( i = 0; i < N; i++ ) data0.row(rng.uniform(0, N0)).copyTo(data.row(i)); kmeans(data, K, labels, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 30, 0), 5, KMEANS_PP_CENTERS); Mat hist(K, 1, CV_32S, Scalar(0)); - for( int i = 0; i < N; i++ ) + for( i = 0; i < N; i++ ) { int l = labels.at(i); - CV_Assert( 0 <= l && l < K ); + CV_Assert(0 <= l && l < K); hist.at(l)++; } - for( int i = 0; i < K; i++ ) + for( i = 0; i < K; i++ ) CV_Assert( hist.at(i) != 0 ); } } catch(...) { + ts->printf(cvtest::TS::LOG, + "context: iteration=%d, N=%d, N0=%d, K=%d\n", + iter, N, N0, K); + std::cout << labels << std::endl; ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH); } }