fixed digits.py sample to work with opencv 3
This commit is contained in:
parent
565d3dde40
commit
a4a2659dff
@ -74,30 +74,35 @@ class StatModel(object):
|
||||
class KNearest(StatModel):
|
||||
def __init__(self, k = 3):
|
||||
self.k = k
|
||||
self.model = cv2.KNearest()
|
||||
self.model = cv2.ml.KNearest_create()
|
||||
|
||||
def train(self, samples, responses):
|
||||
self.model = cv2.KNearest()
|
||||
self.model.train(samples, responses)
|
||||
self.model = cv2.ml.KNearest_create()
|
||||
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
|
||||
|
||||
def predict(self, samples):
|
||||
retval, results, neigh_resp, dists = self.model.find_nearest(samples, self.k)
|
||||
retval, results, neigh_resp, dists = self.model.findNearest(samples, self.k)
|
||||
return results.ravel()
|
||||
|
||||
class SVM(StatModel):
|
||||
def __init__(self, C = 1, gamma = 0.5):
|
||||
self.params = dict( kernel_type = cv2.SVM_RBF,
|
||||
svm_type = cv2.SVM_C_SVC,
|
||||
self.params = dict( kernel_type = cv2.ml.SVM_RBF,
|
||||
svm_type = cv2.ml.SVM_C_SVC,
|
||||
C = C,
|
||||
gamma = gamma )
|
||||
self.model = cv2.SVM()
|
||||
self.model = cv2.ml.SVM_create()
|
||||
|
||||
def train(self, samples, responses):
|
||||
self.model = cv2.SVM()
|
||||
self.model.train(samples, responses, params = self.params)
|
||||
self.model = cv2.ml.SVM_create()
|
||||
""" original code """
|
||||
#self.model.train(samples, responses, params = self.params)
|
||||
""" but it's either this """
|
||||
self.model.train(samples, cv2.ml.ROW_SAMPLE, responses)
|
||||
""" or this """
|
||||
#self.model.train(samples, params = self.params)
|
||||
|
||||
def predict(self, samples):
|
||||
return self.model.predict_all(samples).ravel()
|
||||
return self.model.predict(samples)[1][0].ravel()
|
||||
|
||||
|
||||
def evaluate_model(model, digits, samples, labels):
|
||||
|
Loading…
Reference in New Issue
Block a user