extended cv.KMeans2 API in Python (ticket #414; thanks to hogelog). fixed failures in the opencv python tests.

This commit is contained in:
Vadim Pisarevsky
2010-11-30 10:11:38 +00:00
parent 53e362e403
commit 8754cafffb
4 changed files with 83 additions and 32 deletions

View File

@@ -383,8 +383,8 @@ class FunctionTests(OpenCVTests):
m = cv.CreateMat(rows, cols, t)
self.assertEqual(cv.GetElemType(m), t)
self.assertEqual(m.type, t)
self.assertRaises(cv.error, lambda: cv.CreateMat(0, 100, cv.CV_8SC4))
self.assertRaises(cv.error, lambda: cv.CreateMat(100, 0, cv.CV_8SC4))
self.assertRaises(cv.error, lambda: cv.CreateMat(-1, 100, cv.CV_8SC4))
self.assertRaises(cv.error, lambda: cv.CreateMat(100, -1, cv.CV_8SC4))
self.assertRaises(cv.error, lambda: cv.cvmat())
def test_DrawChessboardCorners(self):
@@ -632,7 +632,7 @@ class FunctionTests(OpenCVTests):
tmp1 = cv.CreateMat(1, 13 * 5, cv.CV_32FC1)
tmp2 = cv.CreateMat(1, 13 * 5, cv.CV_32FC1)
mask = cv.CreateMat(image.rows, image.cols, cv.CV_8UC1)
cv.grabCut(image, mask, (10,10,200,200), tmp1, tmp2, 10, cv.GC_INIT_WITH_RECT)
cv.GrabCut(image, mask, (10,10,200,200), tmp1, tmp2, 10, cv.GC_INIT_WITH_RECT)
def test_HoughLines2_PROBABILISTIC(self):
li = cv.HoughLines2(self.yield_line_image(),
@@ -806,6 +806,36 @@ class FunctionTests(OpenCVTests):
r2 = cv.SnakeImage(cv.GetImage(src), pts, w, w, w, (7,7), (cv.CV_TERMCRIT_ITER, 100, 0.1))
self.assertEqual(r, r2)
def test_KMeans2(self):
size = 500
samples = cv.CreateMat(size, 1, cv.CV_32FC3)
labels = cv.CreateMat(size, 1, cv.CV_32SC1)
centers = cv.CreateMat(2, 3, cv.CV_32FC1)
cv.Zero(samples)
cv.Zero(labels)
cv.Zero(centers)
cv.Set(cv.GetSubRect(samples, (0, 0, 1, size/2)), (255, 255, 255))
compact = cv.KMeans2(samples, 2, labels, (cv.CV_TERMCRIT_ITER, 100, 0.1), 1, 0, centers)
self.assertEqual(int(compact), 0)
random.seed(0)
for i in range(50):
index = random.randrange(size)
if index < size/2:
self.assertEqual(samples[index, 0], (255, 255, 255))
self.assertEqual(labels[index, 0], 1)
else:
self.assertEqual(samples[index, 0], (0, 0, 0))
self.assertEqual(labels[index, 0], 0)
for cluster in (0, 1):
for channel in (0, 1, 2):
self.assertEqual(int(centers[cluster, channel]), cluster*255)
def test_Sum(self):
for r in range(1,11):
for c in range(1, 11):