Support loading old models in ML module
- added test for loading legacy files - added version to new written models - fixed loading of several fields in some models - added generation of new fields from old data
This commit is contained in:
@@ -158,6 +158,109 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
|
||||
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
|
||||
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
|
||||
|
||||
class CV_LegacyTest : public cvtest::BaseTest
|
||||
{
|
||||
public:
|
||||
CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
|
||||
: cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
|
||||
{
|
||||
}
|
||||
virtual ~CV_LegacyTest() {}
|
||||
protected:
|
||||
void run(int)
|
||||
{
|
||||
unsigned int idx = 0;
|
||||
for (;;)
|
||||
{
|
||||
if (idx >= suffixes.size())
|
||||
break;
|
||||
int found = (int)suffixes.find(';', idx);
|
||||
string piece = suffixes.substr(idx, found - idx);
|
||||
if (piece.empty())
|
||||
break;
|
||||
oneTest(piece);
|
||||
idx += (unsigned int)piece.size() + 1;
|
||||
}
|
||||
}
|
||||
void oneTest(const string & suffix)
|
||||
{
|
||||
using namespace cv::ml;
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
|
||||
bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
|
||||
Ptr<StatModel> model;
|
||||
if (modelName == CV_BOOST)
|
||||
model = StatModel::load<Boost>(filename);
|
||||
else if (modelName == CV_ANN)
|
||||
model = StatModel::load<ANN_MLP>(filename);
|
||||
else if (modelName == CV_DTREE)
|
||||
model = StatModel::load<DTrees>(filename);
|
||||
else if (modelName == CV_NBAYES)
|
||||
model = StatModel::load<NormalBayesClassifier>(filename);
|
||||
else if (modelName == CV_SVM)
|
||||
model = StatModel::load<SVM>(filename);
|
||||
else if (modelName == CV_RTREES)
|
||||
model = StatModel::load<RTrees>(filename);
|
||||
if (!model)
|
||||
{
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
}
|
||||
else
|
||||
{
|
||||
Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
|
||||
ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
|
||||
|
||||
if (isTree)
|
||||
randomFillCategories(filename, input);
|
||||
|
||||
Mat output;
|
||||
model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
|
||||
// just check if no internal assertions or errors thrown
|
||||
}
|
||||
ts->set_failed_test_info(code);
|
||||
}
|
||||
void randomFillCategories(const string & filename, Mat & input)
|
||||
{
|
||||
Mat catMap;
|
||||
Mat catCount;
|
||||
std::vector<uchar> varTypes;
|
||||
|
||||
FileStorage fs(filename, FileStorage::READ);
|
||||
FileNode root = fs.getFirstTopLevelNode();
|
||||
root["cat_map"] >> catMap;
|
||||
root["cat_count"] >> catCount;
|
||||
root["var_type"] >> varTypes;
|
||||
|
||||
int offset = 0;
|
||||
int countOffset = 0;
|
||||
uint var = 0, varCount = (uint)varTypes.size();
|
||||
for (; var < varCount; ++var)
|
||||
{
|
||||
if (varTypes[var] == ml::VAR_CATEGORICAL)
|
||||
{
|
||||
int size = catCount.at<int>(0, countOffset);
|
||||
for (int row = 0; row < input.rows; ++row)
|
||||
{
|
||||
int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
|
||||
int value = catMap.at<int>(0, randomChosenIndex);
|
||||
input.at<float>(row, var) = (float)value;
|
||||
}
|
||||
offset += size;
|
||||
++countOffset;
|
||||
}
|
||||
}
|
||||
}
|
||||
string modelName;
|
||||
string suffixes;
|
||||
};
|
||||
|
||||
TEST(ML_ANN, legacy_load) { CV_LegacyTest test(CV_ANN, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_Boost, legacy_load) { CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml"); test.safe_run(); }
|
||||
TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml"); test.safe_run(); }
|
||||
TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
|
||||
|
||||
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
||||
{
|
||||
|
Reference in New Issue
Block a user