Support loading old models in ML module

- added test for loading legacy files
- added version to newly written models
- fixed loading of several fields in some models
- added generation of new fields from old data
This commit is contained in:
Maksim Shabunin
2014-12-16 18:15:50 +03:00
parent 1c8493fb0d
commit d004ee58c5
9 changed files with 211 additions and 13 deletions

View File

@@ -1597,7 +1597,10 @@ void DTreesImpl::writeParams(FileStorage& fs) const
fs << "}";
if( !varIdx.empty() )
{
fs << "global_var_idx" << 1;
fs << "var_idx" << varIdx;
}
fs << "var_type" << varType;
@@ -1726,9 +1729,8 @@ void DTreesImpl::readParams( const FileNode& fn )
if( !tparams_node.empty() ) // training parameters are not necessary
{
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
params0.maxCategories = (int)tparams_node["max_categories"];
params0.maxCategories = (int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]);
params0.regressionAccuracy = (float)tparams_node["regression_accuracy"];
params0.maxDepth = (int)tparams_node["max_depth"];
params0.minSampleCount = (int)tparams_node["min_sample_count"];
params0.CVFolds = (int)tparams_node["cross_validation_folds"];
@@ -1741,13 +1743,83 @@ void DTreesImpl::readParams( const FileNode& fn )
tparams_node["priors"] >> params0.priors;
}
fn["var_idx"] >> varIdx;
readVectorOrMat(fn["var_idx"], varIdx);
fn["var_type"] >> varType;
fn["cat_ofs"] >> catOfs;
fn["cat_map"] >> catMap;
fn["missing_subst"] >> missingSubst;
fn["class_labels"] >> classLabels;
int format = 0;
fn["format"] >> format;
bool isLegacy = format < 3;
int varAll = (int)fn["var_all"];
if (isLegacy && (int)varType.size() <= varAll)
{
std::vector<uchar> extendedTypes(varAll + 1, 0);
int i = 0, n;
if (!varIdx.empty())
{
n = (int)varIdx.size();
for (; i < n; ++i)
{
int var = varIdx[i];
extendedTypes[var] = varType[i];
}
}
else
{
n = (int)varType.size();
for (; i < n; ++i)
{
extendedTypes[i] = varType[i];
}
}
extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
extendedTypes.swap(varType);
}
readVectorOrMat(fn["cat_map"], catMap);
if (isLegacy)
{
// generating "catOfs" from "cat_count"
catOfs.clear();
classLabels.clear();
std::vector<int> counts;
readVectorOrMat(fn["cat_count"], counts);
unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
for (; i < size; ++i)
{
Vec2i newOffsets(0, 0);
if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
{
newOffsets[0] = curShift;
curShift += counts[j];
newOffsets[1] = curShift;
++j;
}
catOfs.push_back(newOffsets);
}
// other elements in "catMap" are "classLabels"
if (curShift < catMap.size())
{
classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
catMap.erase(catMap.begin() + curShift, catMap.end());
}
}
else
{
fn["cat_ofs"] >> catOfs;
fn["missing_subst"] >> missingSubst;
fn["class_labels"] >> classLabels;
}
// init var mapping for node reading (var indexes or varIdx indexes)
bool globalVarIdx = false;
fn["global_var_idx"] >> globalVarIdx;
if (globalVarIdx || varIdx.empty())
setRangeVector(varMapping, (int)varType.size());
else
varMapping = varIdx;
initCompVarIdx();
setDParams(params0);
@@ -1759,6 +1831,7 @@ int DTreesImpl::readSplit( const FileNode& fn )
int vi = (int)fn["var"];
CV_Assert( 0 <= vi && vi <= (int)varType.size() );
vi = varMapping[vi]; // convert to varIdx if needed
split.varIdx = vi;
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var