Support loading old models in ML module
- added test for loading legacy files - added version to new written models - fixed loading of several fields in some models - added generation of new fields from old data
This commit is contained in:
@@ -1597,7 +1597,10 @@ void DTreesImpl::writeParams(FileStorage& fs) const
|
||||
fs << "}";
|
||||
|
||||
if( !varIdx.empty() )
|
||||
{
|
||||
fs << "global_var_idx" << 1;
|
||||
fs << "var_idx" << varIdx;
|
||||
}
|
||||
|
||||
fs << "var_type" << varType;
|
||||
|
||||
@@ -1726,9 +1729,8 @@ void DTreesImpl::readParams( const FileNode& fn )
|
||||
if( !tparams_node.empty() ) // training parameters are not necessary
|
||||
{
|
||||
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
|
||||
params0.maxCategories = (int)tparams_node["max_categories"];
|
||||
params0.maxCategories = (int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]);
|
||||
params0.regressionAccuracy = (float)tparams_node["regression_accuracy"];
|
||||
|
||||
params0.maxDepth = (int)tparams_node["max_depth"];
|
||||
params0.minSampleCount = (int)tparams_node["min_sample_count"];
|
||||
params0.CVFolds = (int)tparams_node["cross_validation_folds"];
|
||||
@@ -1741,13 +1743,83 @@ void DTreesImpl::readParams( const FileNode& fn )
|
||||
tparams_node["priors"] >> params0.priors;
|
||||
}
|
||||
|
||||
fn["var_idx"] >> varIdx;
|
||||
readVectorOrMat(fn["var_idx"], varIdx);
|
||||
fn["var_type"] >> varType;
|
||||
|
||||
fn["cat_ofs"] >> catOfs;
|
||||
fn["cat_map"] >> catMap;
|
||||
fn["missing_subst"] >> missingSubst;
|
||||
fn["class_labels"] >> classLabels;
|
||||
int format = 0;
|
||||
fn["format"] >> format;
|
||||
bool isLegacy = format < 3;
|
||||
|
||||
int varAll = (int)fn["var_all"];
|
||||
if (isLegacy && (int)varType.size() <= varAll)
|
||||
{
|
||||
std::vector<uchar> extendedTypes(varAll + 1, 0);
|
||||
|
||||
int i = 0, n;
|
||||
if (!varIdx.empty())
|
||||
{
|
||||
n = (int)varIdx.size();
|
||||
for (; i < n; ++i)
|
||||
{
|
||||
int var = varIdx[i];
|
||||
extendedTypes[var] = varType[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
n = (int)varType.size();
|
||||
for (; i < n; ++i)
|
||||
{
|
||||
extendedTypes[i] = varType[i];
|
||||
}
|
||||
}
|
||||
extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
|
||||
extendedTypes.swap(varType);
|
||||
}
|
||||
|
||||
readVectorOrMat(fn["cat_map"], catMap);
|
||||
|
||||
if (isLegacy)
|
||||
{
|
||||
// generating "catOfs" from "cat_count"
|
||||
catOfs.clear();
|
||||
classLabels.clear();
|
||||
std::vector<int> counts;
|
||||
readVectorOrMat(fn["cat_count"], counts);
|
||||
unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
|
||||
for (; i < size; ++i)
|
||||
{
|
||||
Vec2i newOffsets(0, 0);
|
||||
if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
|
||||
{
|
||||
newOffsets[0] = curShift;
|
||||
curShift += counts[j];
|
||||
newOffsets[1] = curShift;
|
||||
++j;
|
||||
}
|
||||
catOfs.push_back(newOffsets);
|
||||
}
|
||||
// other elements in "catMap" are "classLabels"
|
||||
if (curShift < catMap.size())
|
||||
{
|
||||
classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
|
||||
catMap.erase(catMap.begin() + curShift, catMap.end());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
fn["cat_ofs"] >> catOfs;
|
||||
fn["missing_subst"] >> missingSubst;
|
||||
fn["class_labels"] >> classLabels;
|
||||
}
|
||||
|
||||
// init var mapping for node reading (var indexes or varIdx indexes)
|
||||
bool globalVarIdx = false;
|
||||
fn["global_var_idx"] >> globalVarIdx;
|
||||
if (globalVarIdx || varIdx.empty())
|
||||
setRangeVector(varMapping, (int)varType.size());
|
||||
else
|
||||
varMapping = varIdx;
|
||||
|
||||
initCompVarIdx();
|
||||
setDParams(params0);
|
||||
@@ -1759,6 +1831,7 @@ int DTreesImpl::readSplit( const FileNode& fn )
|
||||
|
||||
int vi = (int)fn["var"];
|
||||
CV_Assert( 0 <= vi && vi <= (int)varType.size() );
|
||||
vi = varMapping[vi]; // convert to varIdx if needed
|
||||
split.varIdx = vi;
|
||||
|
||||
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
|
||||
|
||||
Reference in New Issue
Block a user