added some constants to python cv2 api
This commit is contained in:
@@ -6,13 +6,6 @@ def load_base(fn):
|
||||
samples, responses = a[:,1:], a[:,0]
|
||||
return samples, responses
|
||||
|
||||
# TODO move these to cv2
|
||||
CV_ROW_SAMPLE = 1
|
||||
CV_VAR_NUMERICAL = 0
|
||||
CV_VAR_ORDERED = 0
|
||||
CV_VAR_CATEGORICAL = 1
|
||||
|
||||
|
||||
class LetterStatModel(object):
|
||||
train_ratio = 0.5
|
||||
def load(self, fn):
|
||||
@@ -26,10 +19,10 @@ class RTrees(LetterStatModel):
|
||||
|
||||
def train(self, samples, responses):
|
||||
sample_n, var_n = samples.shape
|
||||
var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL], np.uint8)
|
||||
var_types = np.array([cv2.CV_VAR_NUMERICAL] * var_n + [cv2.CV_VAR_CATEGORICAL], np.uint8)
|
||||
#CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
|
||||
params = dict(max_depth=10 )
|
||||
self.model.train(samples, CV_ROW_SAMPLE, responses, varType = var_types, params = params)
|
||||
self.model.train(samples, cv2.CV_ROW_SAMPLE, responses, varType = var_types, params = params)
|
||||
|
||||
def predict(self, samples):
|
||||
return np.float32( [self.model.predict(s) for s in samples] )
|
||||
@@ -56,10 +49,10 @@ class Boost(LetterStatModel):
|
||||
sample_n, var_n = samples.shape
|
||||
new_samples = self.unroll_samples(samples)
|
||||
new_responses = self.unroll_responses(responses)
|
||||
var_types = np.array([CV_VAR_NUMERICAL] * var_n + [CV_VAR_CATEGORICAL, CV_VAR_CATEGORICAL], np.uint8)
|
||||
var_types = np.array([cv2.CV_VAR_NUMERICAL] * var_n + [cv2.CV_VAR_CATEGORICAL, cv2.CV_VAR_CATEGORICAL], np.uint8)
|
||||
#CvBoostParams(CvBoost::REAL, 100, 0.95, 5, false, 0 )
|
||||
params = dict(max_depth=5) #, use_surrogates=False)
|
||||
self.model.train(new_samples, CV_ROW_SAMPLE, new_responses, varType = var_types, params=params)
|
||||
self.model.train(new_samples, cv2.CV_ROW_SAMPLE, new_responses, varType = var_types, params=params)
|
||||
|
||||
def predict(self, samples):
|
||||
new_samples = self.unroll_samples(samples)
|
||||
@@ -105,7 +98,7 @@ if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-model', default='rtrees', choices=models.keys())
|
||||
parser.add_argument('-data', nargs=1, default='letter-recognition.data')
|
||||
parser.add_argument('-data', nargs=1, default='../cpp/letter-recognition.data')
|
||||
parser.add_argument('-load', nargs=1)
|
||||
parser.add_argument('-save', nargs=1)
|
||||
args = parser.parse_args()
|
||||
|
Reference in New Issue
Block a user