switched from argparse to getopt for compatibility with Python 2.6
This commit is contained in:
@@ -91,31 +91,34 @@ class SVM(LetterStatModel):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
import getopt
|
||||
import sys
|
||||
|
||||
models = [RTrees, KNearest, Boost, SVM] # MLP, NBayes
|
||||
models = dict( [(cls.__name__.lower(), cls) for cls in models] )
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-model', default='rtrees', choices=models.keys())
|
||||
parser.add_argument('-data', nargs=1, default='../cpp/letter-recognition.data')
|
||||
parser.add_argument('-load', nargs=1)
|
||||
parser.add_argument('-save', nargs=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
print 'loading data %s ...' % args.data
|
||||
samples, responses = load_base(args.data)
|
||||
Model = models[args.model]
|
||||
print 'USAGE: letter_recog.py [--model <model>] [--data <data fn>] [--load <model fn>] [--save <model fn>]'
|
||||
print 'Models: ', ', '.join(models)
|
||||
print
|
||||
|
||||
args, dummy = getopt.getopt(sys.argv[1:], '', ['model=', 'data=', 'load=', 'save='])
|
||||
args = dict(args)
|
||||
args.setdefault('--model', 'rtrees')
|
||||
args.setdefault('--data', '../cpp/letter-recognition.data')
|
||||
|
||||
print 'loading data %s ...' % args['--data']
|
||||
samples, responses = load_base(args['--data'])
|
||||
Model = models[args['--model']]
|
||||
model = Model()
|
||||
|
||||
train_n = int(len(samples)*model.train_ratio)
|
||||
if args.load is None:
|
||||
print 'training %s ...' % Model.__name__
|
||||
model.train(samples[:train_n], responses[:train_n])
|
||||
else:
|
||||
fn = args.load[0]
|
||||
if '--load' in args:
|
||||
fn = args['--load']
|
||||
print 'loading model from %s ...' % fn
|
||||
model.load(fn)
|
||||
else:
|
||||
print 'training %s ...' % Model.__name__
|
||||
model.train(samples[:train_n], responses[:train_n])
|
||||
|
||||
print 'testing...'
|
||||
train_rate = np.mean(model.predict(samples[:train_n]) == responses[:train_n])
|
||||
@@ -123,7 +126,7 @@ if __name__ == '__main__':
|
||||
|
||||
print 'train rate: %f test rate: %f' % (train_rate*100, test_rate*100)
|
||||
|
||||
if args.save is not None:
|
||||
fn = args.save[0]
|
||||
if '--save' in args:
|
||||
fn = args['--save']
|
||||
print 'saving model to %s ...' % fn
|
||||
model.save(fn)
|
||||
|
Reference in New Issue
Block a user