svn repository web references are replaced with links to git
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
'''
|
||||
SVN and KNearest digit recognition.
|
||||
SVM and KNearest digit recognition.
|
||||
|
||||
Sample loads a dataset of handwritten digits from 'digits.png'.
|
||||
Then it trains a SVN and KNearest classifiers on it and evaluates
|
||||
their accuracy.
|
||||
Then it trains SVM and KNearest classifiers on it and evaluates
|
||||
their accuracy.
|
||||
|
||||
The following preprocessing is applied to the dataset:
|
||||
- Moment-based image deskew (see deskew())
|
||||
@@ -77,7 +77,7 @@ class KNearest(StatModel):
|
||||
|
||||
class SVM(StatModel):
|
||||
def __init__(self, C = 1, gamma = 0.5):
|
||||
self.params = dict( kernel_type = cv2.SVM_RBF,
|
||||
self.params = dict( kernel_type = cv2.SVM_RBF,
|
||||
svm_type = cv2.SVM_C_SVC,
|
||||
C = C,
|
||||
gamma = gamma )
|
||||
@@ -95,7 +95,7 @@ def evaluate_model(model, digits, samples, labels):
|
||||
resp = model.predict(samples)
|
||||
err = (labels != resp).mean()
|
||||
print 'error: %.2f %%' % (err*100)
|
||||
|
||||
|
||||
confusion = np.zeros((10, 10), np.int32)
|
||||
for i, j in zip(labels, resp):
|
||||
confusion[i, j] += 1
|
||||
@@ -128,7 +128,7 @@ def preprocess_hog(digits):
|
||||
hist = np.hstack(hists)
|
||||
|
||||
# transform to Hellinger kernel
|
||||
eps = 1e-7
|
||||
eps = 1e-7
|
||||
hist /= hist.sum() + eps
|
||||
hist = np.sqrt(hist)
|
||||
hist /= norm(hist) + eps
|
||||
@@ -141,23 +141,23 @@ if __name__ == '__main__':
|
||||
print __doc__
|
||||
|
||||
digits, labels = load_digits(DIGITS_FN)
|
||||
|
||||
|
||||
print 'preprocessing...'
|
||||
# shuffle digits
|
||||
rand = np.random.RandomState(321)
|
||||
shuffle = rand.permutation(len(digits))
|
||||
digits, labels = digits[shuffle], labels[shuffle]
|
||||
|
||||
|
||||
digits2 = map(deskew, digits)
|
||||
samples = preprocess_hog(digits2)
|
||||
|
||||
|
||||
train_n = int(0.9*len(samples))
|
||||
cv2.imshow('test set', mosaic(25, digits[train_n:]))
|
||||
digits_train, digits_test = np.split(digits2, [train_n])
|
||||
samples_train, samples_test = np.split(samples, [train_n])
|
||||
labels_train, labels_test = np.split(labels, [train_n])
|
||||
|
||||
|
||||
|
||||
print 'training KNearest...'
|
||||
model = KNearest(k=4)
|
||||
model.train(samples_train, labels_train)
|
||||
|
@@ -1,15 +1,15 @@
|
||||
'''
|
||||
Digit recognition adjustment.
|
||||
Grid search is used to find the best parameters for SVN and KNearest classifiers.
|
||||
SVM adjustment follows the guidelines given in
|
||||
Digit recognition adjustment.
|
||||
Grid search is used to find the best parameters for SVM and KNearest classifiers.
|
||||
SVM adjustment follows the guidelines given in
|
||||
http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
|
||||
|
||||
Threading or cloud computing (with http://www.picloud.com/) may be used
|
||||
Threading or cloud computing (with http://www.picloud.com/) may be used
|
||||
to speed up the computation.
|
||||
|
||||
Usage:
|
||||
digits_adjust.py [--model {svm|knearest}] [--cloud] [--env <PiCloud environment>]
|
||||
|
||||
|
||||
--model {svm|knearest} - select the classifier (SVM is the default)
|
||||
--cloud - use PiCloud computing platform
|
||||
--env - cloud environment name
|
||||
@@ -23,12 +23,12 @@ from multiprocessing.pool import ThreadPool
|
||||
|
||||
from digits import *
|
||||
|
||||
try:
|
||||
try:
|
||||
import cloud
|
||||
have_cloud = True
|
||||
except ImportError:
|
||||
have_cloud = False
|
||||
|
||||
|
||||
|
||||
|
||||
def cross_validate(model_class, params, samples, labels, kfold = 3, pool = None):
|
||||
@@ -93,7 +93,7 @@ class App(object):
|
||||
pool = ThreadPool(processes=cv2.getNumberOfCPUs())
|
||||
ires = pool.imap_unordered(f, jobs)
|
||||
return ires
|
||||
|
||||
|
||||
def adjust_SVM(self):
|
||||
Cs = np.logspace(0, 10, 15, base=2)
|
||||
gammas = np.logspace(-7, 4, 15, base=2)
|
||||
@@ -107,7 +107,7 @@ class App(object):
|
||||
params = dict(C = Cs[i], gamma=gammas[j])
|
||||
score = cross_validate(SVM, params, samples, labels)
|
||||
return i, j, score
|
||||
|
||||
|
||||
ires = self.run_jobs(f, np.ndindex(*scores.shape))
|
||||
for count, (i, j, score) in enumerate(ires):
|
||||
scores[i, j] = score
|
||||
@@ -142,7 +142,7 @@ class App(object):
|
||||
if __name__ == '__main__':
|
||||
import getopt
|
||||
import sys
|
||||
|
||||
|
||||
print __doc__
|
||||
|
||||
args, _ = getopt.getopt(sys.argv[1:], '', ['model=', 'cloud', 'env='])
|
||||
|
Reference in New Issue
Block a user