Merge pull request #4070 from techfort:fixpythonsample

This commit is contained in:
Vadim Pisarevsky 2015-05-28 18:06:17 +00:00
commit 8b8fc9e66b

View File

@ -74,30 +74,30 @@ class StatModel(object):
class KNearest(StatModel):
def __init__(self, k = 3):
self.k = k
self.model = cv2.KNearest()
self.model = cv2.ml.KNearest_create()
def train(self, samples, responses):
self.model = cv2.KNearest()
self.model.train(samples, responses)
self.model = cv2.ml.KNearest_create()
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
def predict(self, samples):
retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
retval, results, neigh_resp, dists = self.model.findNearest(samples, self.k)
return results.ravel()
class SVM(StatModel):
def __init__(self, C = 1, gamma = 0.5):
self.params = dict( kernel_type = cv2.SVM_RBF,
svm_type = cv2.SVM_C_SVC,
C = C,
gamma = gamma )
self.model = cv2.SVM()
self.model = cv2.ml.SVM_create()
self.model.setGamma(gamma)
self.model.setC(C)
self.model.setKernel(cv2.ml.SVM_RBF)
self.model.setType(cv2.ml.SVM_C_SVC)
def train(self, samples, responses):
self.model = cv2.SVM()
self.model.train(samples, responses, params = self.params)
self.model = cv2.ml.SVM_create()
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
def predict(self, samples):
return self.model.predict_all(samples).ravel()
return self.model.predict(samples)[1][0].ravel()
def evaluate_model(model, digits, samples, labels):