From d004ee58c5685b86f5c6127895cfcb830ac5c2de Mon Sep 17 00:00:00 2001 From: Maksim Shabunin Date: Tue, 16 Dec 2014 18:15:50 +0300 Subject: [PATCH] Support loading old models in ML module - added test for loading legacy files - added version to new written models - fixed loading of several fields in some models - added generation of new fields from old data --- modules/ml/src/ann_mlp.cpp | 2 +- modules/ml/src/boost.cpp | 8 ++- modules/ml/src/data.cpp | 2 +- modules/ml/src/inner_functions.cpp | 1 + modules/ml/src/precomp.hpp | 16 +++++ modules/ml/src/rtrees.cpp | 2 +- modules/ml/src/svm.cpp | 3 +- modules/ml/src/tree.cpp | 87 ++++++++++++++++++++++-- modules/ml/test/test_save_load.cpp | 103 +++++++++++++++++++++++++++++ 9 files changed, 211 insertions(+), 13 deletions(-) diff --git a/modules/ml/src/ann_mlp.cpp b/modules/ml/src/ann_mlp.cpp index 3e7d44e87..ef5280131 100644 --- a/modules/ml/src/ann_mlp.cpp +++ b/modules/ml/src/ann_mlp.cpp @@ -1241,7 +1241,7 @@ public: clear(); vector _layer_sizes; - fn["layer_sizes"] >> _layer_sizes; + readVectorOrMat(fn["layer_sizes"], _layer_sizes); create( _layer_sizes ); int i, l_count = layer_count(); diff --git a/modules/ml/src/boost.cpp b/modules/ml/src/boost.cpp index 5e0b30733..236cd97a2 100644 --- a/modules/ml/src/boost.cpp +++ b/modules/ml/src/boost.cpp @@ -434,13 +434,17 @@ public: bparams.priors = params0.priors; FileNode tparams_node = fn["training_params"]; - String bts = (String)tparams_node["boosting_type"]; + // check for old layout + String bts = (String)(fn["boosting_type"].empty() ? + tparams_node["boosting_type"] : fn["boosting_type"]); bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE : bts == "RealAdaboost" ? Boost::REAL : bts == "LogitBoost" ? Boost::LOGIT : bts == "GentleAdaboost" ? Boost::GENTLE : -1); _isClassifier = bparams.boostType == Boost::DISCRETE; - bparams.weightTrimRate = (double)tparams_node["weight_trimming_rate"]; + // check for old layout + bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ? + tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]); } void read( const FileNode& fn ) diff --git a/modules/ml/src/data.cpp b/modules/ml/src/data.cpp index 6b5ceb488..d2ac18ff0 100644 --- a/modules/ml/src/data.cpp +++ b/modules/ml/src/data.cpp @@ -898,7 +898,7 @@ public: CV_Assert( m > 0 ); // if m==0, vi is an ordered variable const int* cmap = &catMap.at(ofs[0]); - bool fastMap = (m == cmap[m] - cmap[0]); + bool fastMap = (m == cmap[m - 1] - cmap[0] + 1); if( fastMap ) { diff --git a/modules/ml/src/inner_functions.cpp b/modules/ml/src/inner_functions.cpp index dbc21ff09..561abbaeb 100644 --- a/modules/ml/src/inner_functions.cpp +++ b/modules/ml/src/inner_functions.cpp @@ -115,6 +115,7 @@ void StatModel::save(const String& filename) const { FileStorage fs(filename, FileStorage::WRITE); fs << getDefaultModelName() << "{"; + fs << "format" << (int)3; write(fs); fs << "}"; } diff --git a/modules/ml/src/precomp.hpp b/modules/ml/src/precomp.hpp index d308ae98e..69ff03047 100644 --- a/modules/ml/src/precomp.hpp +++ b/modules/ml/src/precomp.hpp @@ -263,11 +263,27 @@ namespace ml vector subsets; vector classLabels; vector missingSubst; + vector varMapping; bool _isClassifier; Ptr w; }; + template + static inline void readVectorOrMat(const FileNode & node, std::vector & v) + { + if (node.type() == FileNode::MAP) + { + Mat m; + node >> m; + m.copyTo(v); + } + else if (node.type() == FileNode::SEQ) + { + node >> v; + } + } + }} #endif /* __OPENCV_ML_PRECOMP_HPP__ */ diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 7c9cbaf26..7441faac1 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -346,7 +346,7 @@ public: oobError = (double)fn["oob_error"]; int ntrees = (int)fn["ntrees"]; - fn["var_importance"] >> varImportance; + readVectorOrMat(fn["var_importance"], varImportance); readParams(fn); diff --git a/modules/ml/src/svm.cpp b/modules/ml/src/svm.cpp index c7c32f0be..a0df44f78 100644 --- a/modules/ml/src/svm.cpp +++ b/modules/ml/src/svm.cpp @@ -2038,7 +2038,8 @@ public: { Params _params; - String svm_type_str = (String)fn["svmType"]; + // check for old naming + String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]); int svmType = svm_type_str == "C_SVC" ? C_SVC : svm_type_str == "NU_SVC" ? NU_SVC : diff --git a/modules/ml/src/tree.cpp b/modules/ml/src/tree.cpp index 416abd936..64f66169b 100644 --- a/modules/ml/src/tree.cpp +++ b/modules/ml/src/tree.cpp @@ -1597,7 +1597,10 @@ void DTreesImpl::writeParams(FileStorage& fs) const fs << "}"; if( !varIdx.empty() ) + { + fs << "global_var_idx" << 1; fs << "var_idx" << varIdx; + } fs << "var_type" << varType; @@ -1726,9 +1729,8 @@ void DTreesImpl::readParams( const FileNode& fn ) if( !tparams_node.empty() ) // training parameters are not necessary { params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0; - params0.maxCategories = (int)tparams_node["max_categories"]; + params0.maxCategories = (int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]); params0.regressionAccuracy = (float)tparams_node["regression_accuracy"]; - params0.maxDepth = (int)tparams_node["max_depth"]; params0.minSampleCount = (int)tparams_node["min_sample_count"]; params0.CVFolds = (int)tparams_node["cross_validation_folds"]; @@ -1741,13 +1743,83 @@ void DTreesImpl::readParams( const FileNode& fn ) tparams_node["priors"] >> params0.priors; } - fn["var_idx"] >> varIdx; + readVectorOrMat(fn["var_idx"], varIdx); fn["var_type"] >> varType; - fn["cat_ofs"] >> catOfs; - fn["cat_map"] >> catMap; - fn["missing_subst"] >> missingSubst; - fn["class_labels"] >> classLabels; + int format = 0; + fn["format"] >> format; + bool isLegacy = format < 3; + + int varAll = (int)fn["var_all"]; + if (isLegacy && (int)varType.size() <= varAll) + { + std::vector extendedTypes(varAll + 1, 0); + + int i = 0, n; + if (!varIdx.empty()) + { + n = (int)varIdx.size(); + for (; i < n; ++i) + { + int var = varIdx[i]; + extendedTypes[var] = varType[i]; + } + } + else + { + n = (int)varType.size(); + for (; i < n; ++i) + { + extendedTypes[i] = varType[i]; + } + } + extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED); + extendedTypes.swap(varType); + } + + readVectorOrMat(fn["cat_map"], catMap); + + if (isLegacy) + { + // generating "catOfs" from "cat_count" + catOfs.clear(); + classLabels.clear(); + std::vector counts; + readVectorOrMat(fn["cat_count"], counts); + unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1; + for (; i < size; ++i) + { + Vec2i newOffsets(0, 0); + if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap + { + newOffsets[0] = curShift; + curShift += counts[j]; + newOffsets[1] = curShift; + ++j; + } + catOfs.push_back(newOffsets); + } + // other elements in "catMap" are "classLabels" + if (curShift < catMap.size()) + { + classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end()); + catMap.erase(catMap.begin() + curShift, catMap.end()); + } + } + else + { + fn["cat_ofs"] >> catOfs; + fn["missing_subst"] >> missingSubst; + fn["class_labels"] >> classLabels; + } + + // init var mapping for node reading (var indexes or varIdx indexes) + bool globalVarIdx = false; + fn["global_var_idx"] >> globalVarIdx; + if (globalVarIdx || varIdx.empty()) + setRangeVector(varMapping, (int)varType.size()); + else + varMapping = varIdx; initCompVarIdx(); setDParams(params0); @@ -1759,6 +1831,7 @@ int DTreesImpl::readSplit( const FileNode& fn ) int vi = (int)fn["var"]; CV_Assert( 0 <= vi && vi <= (int)varType.size() ); + vi = varMapping[vi]; // convert to varIdx if needed split.varIdx = vi; if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var diff --git a/modules/ml/test/test_save_load.cpp b/modules/ml/test/test_save_load.cpp index bef2fd0e1..74e8eef0d 100644 --- a/modules/ml/test/test_save_load.cpp +++ b/modules/ml/test/test_save_load.cpp @@ -158,6 +158,109 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); } TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); } TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); } +class CV_LegacyTest : public cvtest::BaseTest +{ +public: + CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string()) + : cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes) + { + } + virtual ~CV_LegacyTest() {} +protected: + void run(int) + { + unsigned int idx = 0; + for (;;) + { + if (idx >= suffixes.size()) + break; + int found = (int)suffixes.find(';', idx); + string piece = suffixes.substr(idx, found - idx); + if (piece.empty()) + break; + oneTest(piece); + idx += (unsigned int)piece.size() + 1; + } + } + void oneTest(const string & suffix) + { + using namespace cv::ml; + + int code = cvtest::TS::OK; + string filename = ts->get_data_path() + "legacy/" + modelName + suffix; + bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES; + Ptr model; + if (modelName == CV_BOOST) + model = StatModel::load(filename); + else if (modelName == CV_ANN) + model = StatModel::load(filename); + else if (modelName == CV_DTREE) + model = StatModel::load(filename); + else if (modelName == CV_NBAYES) + model = StatModel::load(filename); + else if (modelName == CV_SVM) + model = StatModel::load(filename); + else if (modelName == CV_RTREES) + model = StatModel::load(filename); + if (!model) + { + code = cvtest::TS::FAIL_INVALID_TEST_DATA; + } + else + { + Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F); + ts->get_rng().fill(input, RNG::UNIFORM, 0, 40); + + if (isTree) + randomFillCategories(filename, input); + + Mat output; + model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0)); + // just check if no internal assertions or errors thrown + } + ts->set_failed_test_info(code); + } + void randomFillCategories(const string & filename, Mat & input) + { + Mat catMap; + Mat catCount; + std::vector varTypes; + + FileStorage fs(filename, FileStorage::READ); + FileNode root = fs.getFirstTopLevelNode(); + root["cat_map"] >> catMap; + root["cat_count"] >> catCount; + root["var_type"] >> varTypes; + + int offset = 0; + int countOffset = 0; + uint var = 0, varCount = (uint)varTypes.size(); + for (; var < varCount; ++var) + { + if (varTypes[var] == ml::VAR_CATEGORICAL) + { + int size = catCount.at(0, countOffset); + for (int row = 0; row < input.rows; ++row) + { + int randomChosenIndex = offset + ((uint)ts->get_rng()) % size; + int value = catMap.at(0, randomChosenIndex); + input.at(row, var) = (float)value; + } + offset += size; + ++countOffset; + } + } + } + string modelName; + string suffixes; +}; + +TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); } +TEST(ML_Boost, legacy_load) { CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml"); test.safe_run(); } +TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml"); test.safe_run(); } +TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); } +TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); } +TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); } /*TEST(ML_SVM, throw_exception_when_save_untrained_model) {