added some constants to python cv2 api

This commit is contained in:
Alexander Mordvintsev
2011-06-06 14:18:25 +00:00
parent cd2f3786f0
commit 6dc7ae0ff6
4 changed files with 21 additions and 20 deletions

View File

@@ -6,13 +6,6 @@ def load_base(fn):
samples, responses = a[:,1:], a[:,0]
return samples, responses
# TODO move these to cv2
CV_ROW_SAMPLE = 1
CV_VAR_NUMERICAL = 0
CV_VAR_ORDERED = 0
CV_VAR_CATEGORICAL = 1
class LetterStatModel(object):
train_ratio = 0.5
def load(self, fn):
@@ -26,10 +19,10 @@ class RTrees(LetterStatModel):
def train(self, samples, responses):
sample_n, var_n = samples.shape
var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL], np.uint8)
var_types = np.array([cv2.CV_VAR_NUMERICAL] * var_n + [cv2.CV_VAR_CATEGORICAL], np.uint8)
#CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
params = dict(max_depth=10 )
self.model.train(samples, CV_ROW_SAMPLE, responses, varType = var_types, params = params)
self.model.train(samples, cv2.CV_ROW_SAMPLE, responses, varType = var_types, params = params)
def predict(self, samples):
return np.float32( [self.model.predict(s) for s in samples] )
@@ -56,10 +49,10 @@ class Boost(LetterStatModel):
sample_n, var_n = samples.shape
new_samples = self.unroll_samples(samples)
new_responses = self.unroll_responses(responses)
var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL, CV_VAR_CATEGORICAL], np.uint8)
var_types = np.array([cv2.CV_VAR_NUMERICAL] * var_n + [cv2.CV_VAR_CATEGORICAL, cv2.CV_VAR_CATEGORICAL], np.uint8)
#CvBoostParams(CvBoost::REAL, 100, 0.95, 5, false, 0 )
params = dict(max_depth=5) #, use_surrogates=False)
self.model.train(new_samples, CV_ROW_SAMPLE, new_responses, varType = var_types, params=params)
self.model.train(new_samples, cv2.CV_ROW_SAMPLE, new_responses, varType = var_types, params=params)
def predict(self, samples):
new_samples = self.unroll_samples(samples)
@@ -105,7 +98,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-model', default='rtrees', choices=models.keys())
parser.add_argument('-data', nargs=1, default='letter-recognition.data')
parser.add_argument('-data', nargs=1, default='../cpp/letter-recognition.data')
parser.add_argument('-load', nargs=1)
parser.add_argument('-save', nargs=1)
args = parser.parse_args()