Fixed variable importance in rtrees

This commit is contained in:
niederb
2015-11-29 23:25:46 +01:00
committed by Vadim Pisarevsky
parent bb4b4acce5
commit d8e3971e7f
2 changed files with 24 additions and 14 deletions

View File

@@ -63,7 +63,6 @@ int main(int argc, char** argv)
const double train_test_split_ratio = 0.5;
Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
if( data.empty() )
{
printf("ERROR: File %s can not be read\n", filename);
@@ -71,6 +70,7 @@ int main(int argc, char** argv)
}
data->setTrainTestSplitRatio(train_test_split_ratio);
std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples();
printf("======DTREE=====\n");
Ptr<DTrees> dtree = DTrees::create();
@@ -106,10 +106,19 @@ int main(int argc, char** argv)
rtrees->setUseSurrogates(false);
rtrees->setMaxCategories(16);
rtrees->setPriors(Mat());
rtrees->setCalculateVarImportance(false);
rtrees->setCalculateVarImportance(true);
rtrees->setActiveVarCount(0);
rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
train_and_print_errs(rtrees, data);
cv::Mat ref_labels = data->getClassLabels();
cv::Mat test_data = data->getTestSampleIdx();
cv::Mat predict_labels;
rtrees->predict(data->getSamples(), predict_labels);
cv::Mat variable_importance = rtrees->getVarImportance();
std::cout << "Estimated variable importance" << std::endl;
for (int i = 0; i < variable_importance.rows; i++) {
std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
}
return 0;
}