67 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python
 | |
| '''
 | |
| ===============================================================================
 | |
| Interactive Image Segmentation using GrabCut algorithm.
 | |
| ===============================================================================
 | |
| '''
 | |
| 
 | |
| # Python 2/3 compatibility
 | |
| from __future__ import print_function
 | |
| 
 | |
| import numpy as np
 | |
| import cv2
 | |
| import sys
 | |
| 
 | |
| from tests_common import NewOpenCVTests
 | |
| 
 | |
class grabcut_test(NewOpenCVTests):
    '''Regression test comparing cv2.grabCut output with stored reference masks.'''

    def verify(self, mask, exp):
        '''Tell whether `mask` is close enough to the expected mask `exp`.

        The number of pixels where the two arrays differ is divided by the
        number of non-zero pixels in `exp`; the masks match when that ratio
        is below 2%.

        If `exp` has no non-zero pixels the ratio is undefined, so the masks
        are accepted only when they are identical.

        Returns a bool.
        '''
        maxDiffRatio = 0.02
        expArea = np.count_nonzero(exp)
        nonIntersectArea = np.count_nonzero(mask != exp)

        if expArea == 0:
            # Guard against ZeroDivisionError on an empty expected mask.
            return nonIntersectArea == 0

        # float() keeps true division under Python 2 as well.
        curRatio = float(nonIntersectArea) / expArea
        return curRatio < maxDiffRatio

    def scaleMask(self, mask):
        '''Convert a grabCut label mask into a 0/255 uint8 image.

        Pixels labelled cv2.GC_FGD or cv2.GC_PR_FGD become 255; every other
        pixel becomes 0.
        '''
        isForeground = (mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD)
        return np.where(isForeground, 255, 0).astype('uint8')

    def test_grabcut(self):
        '''Run grabCut from a rectangle and from a mask, checking both results.

        Any reference image that get_sample returns as None is regenerated
        from the current result and written under extraTestDataPath, so the
        first run on a fresh data directory records the baseline.
        '''
        img = self.get_sample('cv/shared/airplane.png')
        # NOTE(review): the second argument 0 presumably requests grayscale
        # loading -- confirm against get_sample.
        mask_prob = self.get_sample("cv/grabcut/mask_probpy.png", 0)
        exp_mask1 = self.get_sample("cv/grabcut/exp_mask1py.png", 0)
        exp_mask2 = self.get_sample("cv/grabcut/exp_mask2py.png", 0)

        if img is None:
            self.fail('Missing test data')

        # Stage 1: initialise from a rectangle, then run the evaluation pass.
        rect = (24, 126, 459, 168)
        mask = np.zeros(img.shape[:2], dtype=np.uint8)
        bgdModel = np.zeros((1, 65), np.float64)
        fgdModel = np.zeros((1, 65), np.float64)
        cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv2.GC_INIT_WITH_RECT)
        cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 2, cv2.GC_EVAL)

        if mask_prob is None:
            # Copy so that stage 2 below does not modify the stage 1 mask.
            mask_prob = mask.copy()
            cv2.imwrite(self.extraTestDataPath + '/cv/grabcut/mask_probpy.png', mask_prob)
        if exp_mask1 is None:
            exp_mask1 = self.scaleMask(mask)
            cv2.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask1py.png', exp_mask1)

        self.assertTrue(self.verify(self.scaleMask(mask), exp_mask1))

        # Stage 2: initialise from the label mask with fresh (zeroed) models.
        mask = mask_prob
        bgdModel = np.zeros((1, 65), np.float64)
        fgdModel = np.zeros((1, 65), np.float64)
        cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv2.GC_INIT_WITH_MASK)
        cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 1, cv2.GC_EVAL)

        if exp_mask2 is None:
            exp_mask2 = self.scaleMask(mask)
            cv2.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask2py.png', exp_mask2)

        self.assertTrue(self.verify(self.scaleMask(mask), exp_mask2))
