HoG and Hellinger-metric preprocess for digit recognition

line breaks in fitline.py description
This commit is contained in:
Alexander Mordvintsev 2012-07-02 13:49:36 +00:00
parent efe139667b
commit 1543b46383
4 changed files with 69 additions and 24 deletions

View File

@ -3,8 +3,19 @@ SVN and KNearest digit recognition.
Sample loads a dataset of handwritten digits from 'digits.png'. Sample loads a dataset of handwritten digits from 'digits.png'.
Then it trains a SVN and KNearest classifiers on it and evaluates Then it trains a SVN and KNearest classifiers on it and evaluates
their accuracy. Moment-based image deskew is used to improve their accuracy.
the recognition accuracy.
Following preprocessing is applied to the dataset:
- Moment-based image deskew (see deskew())
- Digit images are split into 4 10x10 cells and 16-bin
histogram of oriented gradients is computed for each
cell
- Transform histograms to space with Hellinger metric (see [1] (RootSIFT))
[1] R. Arandjelovic, A. Zisserman
"Three things everyone should know to improve object retrieval"
http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
Usage: Usage:
digits.py digits.py
@ -14,17 +25,25 @@ import numpy as np
import cv2 import cv2
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from common import clock, mosaic from common import clock, mosaic
from numpy.linalg import norm
SZ = 20 # size of each digit is SZ x SZ SZ = 20 # size of each digit is SZ x SZ
CLASS_N = 10 CLASS_N = 10
DIGITS_FN = 'digits.png' DIGITS_FN = 'digits.png'
def split2d(img, cell_size, flatten=True):
h, w = img.shape[:2]
sx, sy = cell_size
cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
cells = np.array(cells)
if flatten:
cells = cells.reshape(-1, sy, sx)
return cells
def load_digits(fn): def load_digits(fn):
print 'loading "%s" ...' % fn print 'loading "%s" ...' % fn
digits_img = cv2.imread(fn, 0) digits_img = cv2.imread(fn, 0)
h, w = digits_img.shape digits = split2d(digits_img, (SZ, SZ))
digits = [np.hsplit(row, w/SZ) for row in np.vsplit(digits_img, h/SZ)]
digits = np.array(digits).reshape(-1, SZ, SZ)
labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N) labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
return digits, labels return digits, labels
@ -92,6 +111,31 @@ def evaluate_model(model, digits, samples, labels):
vis.append(img) vis.append(img)
return mosaic(25, vis) return mosaic(25, vis)
def preprocess_simple(digits):
return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
def preprocess_hog(digits):
samples = []
for img in digits:
gx = cv2.Sobel(img, cv2.CV_32F, 1, 0)
gy = cv2.Sobel(img, cv2.CV_32F, 0, 1)
mag, ang = cv2.cartToPolar(gx, gy)
bin_n = 16
bin = np.int32(bin_n*ang/(2*np.pi))
bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
hist = np.hstack(hists)
# transform to Hellinger kernel
eps = 1e-7
hist /= hist.sum() + eps
hist = np.sqrt(hist)
hist /= norm(hist) + eps
samples.append(hist)
return np.float32(samples)
if __name__ == '__main__': if __name__ == '__main__':
print __doc__ print __doc__
@ -100,13 +144,13 @@ if __name__ == '__main__':
print 'preprocessing...' print 'preprocessing...'
# shuffle digits # shuffle digits
rand = np.random.RandomState(12345) rand = np.random.RandomState(321)
shuffle = rand.permutation(len(digits)) shuffle = rand.permutation(len(digits))
digits, labels = digits[shuffle], labels[shuffle] digits, labels = digits[shuffle], labels[shuffle]
digits2 = map(deskew, digits) digits2 = map(deskew, digits)
samples = np.float32(digits2).reshape(-1, SZ*SZ) / 255.0 samples = preprocess_hog(digits2)
train_n = int(0.9*len(samples)) train_n = int(0.9*len(samples))
cv2.imshow('test set', mosaic(25, digits[train_n:])) cv2.imshow('test set', mosaic(25, digits[train_n:]))
digits_train, digits_test = np.split(digits2, [train_n]) digits_train, digits_test = np.split(digits2, [train_n])
@ -115,13 +159,13 @@ if __name__ == '__main__':
print 'training KNearest...' print 'training KNearest...'
model = KNearest(k=1) model = KNearest(k=4)
model.train(samples_train, labels_train) model.train(samples_train, labels_train)
vis = evaluate_model(model, digits_test, samples_test, labels_test) vis = evaluate_model(model, digits_test, samples_test, labels_test)
cv2.imshow('KNearest test', vis) cv2.imshow('KNearest test', vis)
print 'training SVM...' print 'training SVM...'
model = SVM(C=4.66, gamma=0.08) model = SVM(C=2.67, gamma=5.383)
model.train(samples_train, labels_train) model.train(samples_train, labels_train)
vis = evaluate_model(model, digits_test, samples_test, labels_test) vis = evaluate_model(model, digits_test, samples_test, labels_test)
cv2.imshow('SVM test', vis) cv2.imshow('SVM test', vis)

View File

@ -76,7 +76,7 @@ class App(object):
shuffle = np.random.permutation(len(digits)) shuffle = np.random.permutation(len(digits))
digits, labels = digits[shuffle], labels[shuffle] digits, labels = digits[shuffle], labels[shuffle]
digits2 = map(deskew, digits) digits2 = map(deskew, digits)
samples = np.float32(digits2).reshape(-1, SZ*SZ) / 255.0 samples = preprocess_hog(digits2)
return samples, labels return samples, labels
def get_dataset(self): def get_dataset(self):
@ -95,8 +95,8 @@ class App(object):
return ires return ires
def adjust_SVM(self): def adjust_SVM(self):
Cs = np.logspace(0, 5, 10, base=2) Cs = np.logspace(0, 10, 15, base=2)
gammas = np.logspace(-7, -2, 10, base=2) gammas = np.logspace(-7, 4, 15, base=2)
scores = np.zeros((len(Cs), len(gammas))) scores = np.zeros((len(Cs), len(gammas)))
scores[:] = np.nan scores[:] = np.nan
@ -114,6 +114,9 @@ class App(object):
print '%d / %d (best error: %.2f %%, last: %.2f %%)' % (count+1, scores.size, np.nanmin(scores)*100, score*100) print '%d / %d (best error: %.2f %%, last: %.2f %%)' % (count+1, scores.size, np.nanmin(scores)*100, score*100)
print scores print scores
print 'writing score table to "svm_scores.npz"'
np.savez('svm_scores.npz', scores=scores, Cs=Cs, gammas=gammas)
i, j = np.unravel_index(scores.argmin(), scores.shape) i, j = np.unravel_index(scores.argmin(), scores.shape)
best_params = dict(C = Cs[i], gamma=gammas[j]) best_params = dict(C = Cs[i], gamma=gammas[j])
print 'best params:', best_params print 'best params:', best_params
@ -142,7 +145,6 @@ if __name__ == '__main__':
print __doc__ print __doc__
args, _ = getopt.getopt(sys.argv[1:], '', ['model=', 'cloud', 'env=']) args, _ = getopt.getopt(sys.argv[1:], '', ['model=', 'cloud', 'env='])
args = dict(args) args = dict(args)
args.setdefault('--model', 'svm') args.setdefault('--model', 'svm')

View File

@ -1,10 +1,10 @@
import numpy as np import numpy as np
import cv2 import cv2
import digits
import os import os
import video import video
from common import mosaic from common import mosaic
from digits import *
def main(): def main():
@ -15,11 +15,9 @@ def main():
print '"%s" not found, run digits.py first' % classifier_fn print '"%s" not found, run digits.py first' % classifier_fn
return return
model = digits.SVM() model = SVM()
model.load('digits_svm.dat') model.load('digits_svm.dat')
SZ = 20
while True: while True:
ret, frame = cap.read() ret, frame = cap.read()
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
@ -55,13 +53,12 @@ def main():
A[:,:2] = np.eye(2)*s A[:,:2] = np.eye(2)*s
A[:,2] = t A[:,2] = t
sub1 = cv2.warpAffine(sub, A, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR) sub1 = cv2.warpAffine(sub, A, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
sub1 = digits.deskew(sub1) sub1 = deskew(sub1)
if x+w+SZ < frame.shape[1] and y+SZ < frame.shape[0]: if x+w+SZ < frame.shape[1] and y+SZ < frame.shape[0]:
frame[y:,x+w:][:SZ, :SZ] = sub1[...,np.newaxis] frame[y:,x+w:][:SZ, :SZ] = sub1[...,np.newaxis]
sample = np.float32(sub1).reshape(1,SZ*SZ) / 255.0 sample = preprocess_hog([sub1])
digit = model.predict(sample)[0] digit = model.predict(sample)[0]
cv2.putText(frame, '%d'%digit, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0, (200, 0, 0), thickness = 1) cv2.putText(frame, '%d'%digit, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0, (200, 0, 0), thickness = 1)

View File

@ -2,14 +2,16 @@
Robust line fitting. Robust line fitting.
================== ==================
Example of using cv2.fitLine function for fitting line to points in presence of outliers. Example of using cv2.fitLine function for fitting line
to points in presence of outliers.
Usage Usage
----- -----
fitline.py fitline.py
Switch through different M-estimator functions and see, how well the robust functions Switch through different M-estimator functions and see,
fit the line even in case of ~50% of outliers. how well the robust functions fit the line even
in case of ~50% of outliers.
Keys Keys
---- ----