Support loading old models in ML module

- added test for loading legacy files
- added version to new written models
- fixed loading of several fields in some models
- added generation of new fields from old data
This commit is contained in:
Maksim Shabunin
2014-12-16 18:15:50 +03:00
parent 1c8493fb0d
commit d004ee58c5
9 changed files with 211 additions and 13 deletions

View File

@@ -158,6 +158,109 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
// Save/load round-trip check for the random trees model.
TEST(ML_RTrees, save_load)
{
    CV_SLMLTest test( CV_RTREES );
    test.safe_run();
}

// Same check for extremely randomized trees; the test name carries the DISABLED_ prefix.
TEST(DISABLED_ML_ERTrees, save_load)
{
    CV_SLMLTest test( CV_ERTREES );
    test.safe_run();
}
class CV_LegacyTest : public cvtest::BaseTest
{
public:
CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
: cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
{
}
virtual ~CV_LegacyTest() {}
protected:
void run(int)
{
unsigned int idx = 0;
for (;;)
{
if (idx >= suffixes.size())
break;
int found = (int)suffixes.find(';', idx);
string piece = suffixes.substr(idx, found - idx);
if (piece.empty())
break;
oneTest(piece);
idx += (unsigned int)piece.size() + 1;
}
}
void oneTest(const string & suffix)
{
using namespace cv::ml;
int code = cvtest::TS::OK;
string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
Ptr<StatModel> model;
if (modelName == CV_BOOST)
model = StatModel::load<Boost>(filename);
else if (modelName == CV_ANN)
model = StatModel::load<ANN_MLP>(filename);
else if (modelName == CV_DTREE)
model = StatModel::load<DTrees>(filename);
else if (modelName == CV_NBAYES)
model = StatModel::load<NormalBayesClassifier>(filename);
else if (modelName == CV_SVM)
model = StatModel::load<SVM>(filename);
else if (modelName == CV_RTREES)
model = StatModel::load<RTrees>(filename);
if (!model)
{
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
else
{
Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
if (isTree)
randomFillCategories(filename, input);
Mat output;
model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
// just check if no internal assertions or errors thrown
}
ts->set_failed_test_info(code);
}
void randomFillCategories(const string & filename, Mat & input)
{
Mat catMap;
Mat catCount;
std::vector<uchar> varTypes;
FileStorage fs(filename, FileStorage::READ);
FileNode root = fs.getFirstTopLevelNode();
root["cat_map"] >> catMap;
root["cat_count"] >> catCount;
root["var_type"] >> varTypes;
int offset = 0;
int countOffset = 0;
uint var = 0, varCount = (uint)varTypes.size();
for (; var < varCount; ++var)
{
if (varTypes[var] == ml::VAR_CATEGORICAL)
{
int size = catCount.at<int>(0, countOffset);
for (int row = 0; row < input.rows; ++row)
{
int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
int value = catMap.at<int>(0, randomChosenIndex);
input.at<float>(row, var) = (float)value;
}
offset += size;
++countOffset;
}
}
}
string modelName;
string suffixes;
};
// Legacy model loading tests: one per model type, each listing the
// ';'-separated file suffixes of the stored models to load.
TEST(ML_ANN, legacy_load)
{
    CV_LegacyTest test(CV_ANN, "_waveform.xml");
    test.safe_run();
}

TEST(ML_Boost, legacy_load)
{
    CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml");
    test.safe_run();
}

TEST(ML_DTree, legacy_load)
{
    CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml");
    test.safe_run();
}

TEST(ML_NBayes, legacy_load)
{
    CV_LegacyTest test(CV_NBAYES, "_waveform.xml");
    test.safe_run();
}

TEST(ML_SVM, legacy_load)
{
    CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml");
    test.safe_run();
}

TEST(ML_RTrees, legacy_load)
{
    CV_LegacyTest test(CV_RTREES, "_waveform.xml");
    test.safe_run();
}
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
{