Merge pull request #3537 from mshabunin:ml-old-xml
This commit is contained in:
commit
bc23f29b27
@ -103,6 +103,7 @@ public:
|
|||||||
|
|
||||||
ANN_MLPImpl( const Params& p )
|
ANN_MLPImpl( const Params& p )
|
||||||
{
|
{
|
||||||
|
clear();
|
||||||
setParams(p);
|
setParams(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,6 +127,7 @@ public:
|
|||||||
rng = RNG((uint64)-1);
|
rng = RNG((uint64)-1);
|
||||||
weights.clear();
|
weights.clear();
|
||||||
trained = false;
|
trained = false;
|
||||||
|
max_buf_sz = 1 << 12;
|
||||||
}
|
}
|
||||||
|
|
||||||
int layer_count() const { return (int)layer_sizes.size(); }
|
int layer_count() const { return (int)layer_sizes.size(); }
|
||||||
@ -1241,7 +1243,7 @@ public:
|
|||||||
clear();
|
clear();
|
||||||
|
|
||||||
vector<int> _layer_sizes;
|
vector<int> _layer_sizes;
|
||||||
fn["layer_sizes"] >> _layer_sizes;
|
readVectorOrMat(fn["layer_sizes"], _layer_sizes);
|
||||||
create( _layer_sizes );
|
create( _layer_sizes );
|
||||||
|
|
||||||
int i, l_count = layer_count();
|
int i, l_count = layer_count();
|
||||||
|
@ -434,13 +434,17 @@ public:
|
|||||||
bparams.priors = params0.priors;
|
bparams.priors = params0.priors;
|
||||||
|
|
||||||
FileNode tparams_node = fn["training_params"];
|
FileNode tparams_node = fn["training_params"];
|
||||||
String bts = (String)tparams_node["boosting_type"];
|
// check for old layout
|
||||||
|
String bts = (String)(fn["boosting_type"].empty() ?
|
||||||
|
tparams_node["boosting_type"] : fn["boosting_type"]);
|
||||||
bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
|
bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
|
||||||
bts == "RealAdaboost" ? Boost::REAL :
|
bts == "RealAdaboost" ? Boost::REAL :
|
||||||
bts == "LogitBoost" ? Boost::LOGIT :
|
bts == "LogitBoost" ? Boost::LOGIT :
|
||||||
bts == "GentleAdaboost" ? Boost::GENTLE : -1);
|
bts == "GentleAdaboost" ? Boost::GENTLE : -1);
|
||||||
_isClassifier = bparams.boostType == Boost::DISCRETE;
|
_isClassifier = bparams.boostType == Boost::DISCRETE;
|
||||||
bparams.weightTrimRate = (double)tparams_node["weight_trimming_rate"];
|
// check for old layout
|
||||||
|
bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
|
||||||
|
tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void read( const FileNode& fn )
|
void read( const FileNode& fn )
|
||||||
|
@ -898,7 +898,7 @@ public:
|
|||||||
|
|
||||||
CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
|
CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
|
||||||
const int* cmap = &catMap.at<int>(ofs[0]);
|
const int* cmap = &catMap.at<int>(ofs[0]);
|
||||||
bool fastMap = (m == cmap[m] - cmap[0]);
|
bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
|
||||||
|
|
||||||
if( fastMap )
|
if( fastMap )
|
||||||
{
|
{
|
||||||
|
@ -115,6 +115,7 @@ void StatModel::save(const String& filename) const
|
|||||||
{
|
{
|
||||||
FileStorage fs(filename, FileStorage::WRITE);
|
FileStorage fs(filename, FileStorage::WRITE);
|
||||||
fs << getDefaultModelName() << "{";
|
fs << getDefaultModelName() << "{";
|
||||||
|
fs << "format" << (int)3;
|
||||||
write(fs);
|
write(fs);
|
||||||
fs << "}";
|
fs << "}";
|
||||||
}
|
}
|
||||||
|
@ -263,11 +263,27 @@ namespace ml
|
|||||||
vector<int> subsets;
|
vector<int> subsets;
|
||||||
vector<int> classLabels;
|
vector<int> classLabels;
|
||||||
vector<float> missingSubst;
|
vector<float> missingSubst;
|
||||||
|
vector<int> varMapping;
|
||||||
bool _isClassifier;
|
bool _isClassifier;
|
||||||
|
|
||||||
Ptr<WorkData> w;
|
Ptr<WorkData> w;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
|
||||||
|
{
|
||||||
|
if (node.type() == FileNode::MAP)
|
||||||
|
{
|
||||||
|
Mat m;
|
||||||
|
node >> m;
|
||||||
|
m.copyTo(v);
|
||||||
|
}
|
||||||
|
else if (node.type() == FileNode::SEQ)
|
||||||
|
{
|
||||||
|
node >> v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
#endif /* __OPENCV_ML_PRECOMP_HPP__ */
|
#endif /* __OPENCV_ML_PRECOMP_HPP__ */
|
||||||
|
@ -346,7 +346,7 @@ public:
|
|||||||
oobError = (double)fn["oob_error"];
|
oobError = (double)fn["oob_error"];
|
||||||
int ntrees = (int)fn["ntrees"];
|
int ntrees = (int)fn["ntrees"];
|
||||||
|
|
||||||
fn["var_importance"] >> varImportance;
|
readVectorOrMat(fn["var_importance"], varImportance);
|
||||||
|
|
||||||
readParams(fn);
|
readParams(fn);
|
||||||
|
|
||||||
|
@ -2038,7 +2038,8 @@ public:
|
|||||||
{
|
{
|
||||||
Params _params;
|
Params _params;
|
||||||
|
|
||||||
String svm_type_str = (String)fn["svmType"];
|
// check for old naming
|
||||||
|
String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
|
||||||
int svmType =
|
int svmType =
|
||||||
svm_type_str == "C_SVC" ? C_SVC :
|
svm_type_str == "C_SVC" ? C_SVC :
|
||||||
svm_type_str == "NU_SVC" ? NU_SVC :
|
svm_type_str == "NU_SVC" ? NU_SVC :
|
||||||
|
@ -1597,7 +1597,10 @@ void DTreesImpl::writeParams(FileStorage& fs) const
|
|||||||
fs << "}";
|
fs << "}";
|
||||||
|
|
||||||
if( !varIdx.empty() )
|
if( !varIdx.empty() )
|
||||||
|
{
|
||||||
|
fs << "global_var_idx" << 1;
|
||||||
fs << "var_idx" << varIdx;
|
fs << "var_idx" << varIdx;
|
||||||
|
}
|
||||||
|
|
||||||
fs << "var_type" << varType;
|
fs << "var_type" << varType;
|
||||||
|
|
||||||
@ -1726,9 +1729,8 @@ void DTreesImpl::readParams( const FileNode& fn )
|
|||||||
if( !tparams_node.empty() ) // training parameters are not necessary
|
if( !tparams_node.empty() ) // training parameters are not necessary
|
||||||
{
|
{
|
||||||
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
|
params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
|
||||||
params0.maxCategories = (int)tparams_node["max_categories"];
|
params0.maxCategories = (int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]);
|
||||||
params0.regressionAccuracy = (float)tparams_node["regression_accuracy"];
|
params0.regressionAccuracy = (float)tparams_node["regression_accuracy"];
|
||||||
|
|
||||||
params0.maxDepth = (int)tparams_node["max_depth"];
|
params0.maxDepth = (int)tparams_node["max_depth"];
|
||||||
params0.minSampleCount = (int)tparams_node["min_sample_count"];
|
params0.minSampleCount = (int)tparams_node["min_sample_count"];
|
||||||
params0.CVFolds = (int)tparams_node["cross_validation_folds"];
|
params0.CVFolds = (int)tparams_node["cross_validation_folds"];
|
||||||
@ -1741,13 +1743,83 @@ void DTreesImpl::readParams( const FileNode& fn )
|
|||||||
tparams_node["priors"] >> params0.priors;
|
tparams_node["priors"] >> params0.priors;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn["var_idx"] >> varIdx;
|
readVectorOrMat(fn["var_idx"], varIdx);
|
||||||
fn["var_type"] >> varType;
|
fn["var_type"] >> varType;
|
||||||
|
|
||||||
fn["cat_ofs"] >> catOfs;
|
int format = 0;
|
||||||
fn["cat_map"] >> catMap;
|
fn["format"] >> format;
|
||||||
fn["missing_subst"] >> missingSubst;
|
bool isLegacy = format < 3;
|
||||||
fn["class_labels"] >> classLabels;
|
|
||||||
|
int varAll = (int)fn["var_all"];
|
||||||
|
if (isLegacy && (int)varType.size() <= varAll)
|
||||||
|
{
|
||||||
|
std::vector<uchar> extendedTypes(varAll + 1, 0);
|
||||||
|
|
||||||
|
int i = 0, n;
|
||||||
|
if (!varIdx.empty())
|
||||||
|
{
|
||||||
|
n = (int)varIdx.size();
|
||||||
|
for (; i < n; ++i)
|
||||||
|
{
|
||||||
|
int var = varIdx[i];
|
||||||
|
extendedTypes[var] = varType[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
n = (int)varType.size();
|
||||||
|
for (; i < n; ++i)
|
||||||
|
{
|
||||||
|
extendedTypes[i] = varType[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
|
||||||
|
extendedTypes.swap(varType);
|
||||||
|
}
|
||||||
|
|
||||||
|
readVectorOrMat(fn["cat_map"], catMap);
|
||||||
|
|
||||||
|
if (isLegacy)
|
||||||
|
{
|
||||||
|
// generating "catOfs" from "cat_count"
|
||||||
|
catOfs.clear();
|
||||||
|
classLabels.clear();
|
||||||
|
std::vector<int> counts;
|
||||||
|
readVectorOrMat(fn["cat_count"], counts);
|
||||||
|
unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
|
||||||
|
for (; i < size; ++i)
|
||||||
|
{
|
||||||
|
Vec2i newOffsets(0, 0);
|
||||||
|
if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
|
||||||
|
{
|
||||||
|
newOffsets[0] = curShift;
|
||||||
|
curShift += counts[j];
|
||||||
|
newOffsets[1] = curShift;
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
catOfs.push_back(newOffsets);
|
||||||
|
}
|
||||||
|
// other elements in "catMap" are "classLabels"
|
||||||
|
if (curShift < catMap.size())
|
||||||
|
{
|
||||||
|
classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
|
||||||
|
catMap.erase(catMap.begin() + curShift, catMap.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
fn["cat_ofs"] >> catOfs;
|
||||||
|
fn["missing_subst"] >> missingSubst;
|
||||||
|
fn["class_labels"] >> classLabels;
|
||||||
|
}
|
||||||
|
|
||||||
|
// init var mapping for node reading (var indexes or varIdx indexes)
|
||||||
|
bool globalVarIdx = false;
|
||||||
|
fn["global_var_idx"] >> globalVarIdx;
|
||||||
|
if (globalVarIdx || varIdx.empty())
|
||||||
|
setRangeVector(varMapping, (int)varType.size());
|
||||||
|
else
|
||||||
|
varMapping = varIdx;
|
||||||
|
|
||||||
initCompVarIdx();
|
initCompVarIdx();
|
||||||
setDParams(params0);
|
setDParams(params0);
|
||||||
@ -1759,6 +1831,7 @@ int DTreesImpl::readSplit( const FileNode& fn )
|
|||||||
|
|
||||||
int vi = (int)fn["var"];
|
int vi = (int)fn["var"];
|
||||||
CV_Assert( 0 <= vi && vi <= (int)varType.size() );
|
CV_Assert( 0 <= vi && vi <= (int)varType.size() );
|
||||||
|
vi = varMapping[vi]; // convert to varIdx if needed
|
||||||
split.varIdx = vi;
|
split.varIdx = vi;
|
||||||
|
|
||||||
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
|
if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
|
||||||
|
@ -158,6 +158,109 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
|
|||||||
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
|
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
|
||||||
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
|
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
|
||||||
|
|
||||||
|
class CV_LegacyTest : public cvtest::BaseTest
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
|
||||||
|
: cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
virtual ~CV_LegacyTest() {}
|
||||||
|
protected:
|
||||||
|
void run(int)
|
||||||
|
{
|
||||||
|
unsigned int idx = 0;
|
||||||
|
for (;;)
|
||||||
|
{
|
||||||
|
if (idx >= suffixes.size())
|
||||||
|
break;
|
||||||
|
int found = (int)suffixes.find(';', idx);
|
||||||
|
string piece = suffixes.substr(idx, found - idx);
|
||||||
|
if (piece.empty())
|
||||||
|
break;
|
||||||
|
oneTest(piece);
|
||||||
|
idx += (unsigned int)piece.size() + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void oneTest(const string & suffix)
|
||||||
|
{
|
||||||
|
using namespace cv::ml;
|
||||||
|
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
|
||||||
|
bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
|
||||||
|
Ptr<StatModel> model;
|
||||||
|
if (modelName == CV_BOOST)
|
||||||
|
model = StatModel::load<Boost>(filename);
|
||||||
|
else if (modelName == CV_ANN)
|
||||||
|
model = StatModel::load<ANN_MLP>(filename);
|
||||||
|
else if (modelName == CV_DTREE)
|
||||||
|
model = StatModel::load<DTrees>(filename);
|
||||||
|
else if (modelName == CV_NBAYES)
|
||||||
|
model = StatModel::load<NormalBayesClassifier>(filename);
|
||||||
|
else if (modelName == CV_SVM)
|
||||||
|
model = StatModel::load<SVM>(filename);
|
||||||
|
else if (modelName == CV_RTREES)
|
||||||
|
model = StatModel::load<RTrees>(filename);
|
||||||
|
if (!model)
|
||||||
|
{
|
||||||
|
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
|
||||||
|
ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
|
||||||
|
|
||||||
|
if (isTree)
|
||||||
|
randomFillCategories(filename, input);
|
||||||
|
|
||||||
|
Mat output;
|
||||||
|
model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
|
||||||
|
// just check if no internal assertions or errors thrown
|
||||||
|
}
|
||||||
|
ts->set_failed_test_info(code);
|
||||||
|
}
|
||||||
|
void randomFillCategories(const string & filename, Mat & input)
|
||||||
|
{
|
||||||
|
Mat catMap;
|
||||||
|
Mat catCount;
|
||||||
|
std::vector<uchar> varTypes;
|
||||||
|
|
||||||
|
FileStorage fs(filename, FileStorage::READ);
|
||||||
|
FileNode root = fs.getFirstTopLevelNode();
|
||||||
|
root["cat_map"] >> catMap;
|
||||||
|
root["cat_count"] >> catCount;
|
||||||
|
root["var_type"] >> varTypes;
|
||||||
|
|
||||||
|
int offset = 0;
|
||||||
|
int countOffset = 0;
|
||||||
|
uint var = 0, varCount = (uint)varTypes.size();
|
||||||
|
for (; var < varCount; ++var)
|
||||||
|
{
|
||||||
|
if (varTypes[var] == ml::VAR_CATEGORICAL)
|
||||||
|
{
|
||||||
|
int size = catCount.at<int>(0, countOffset);
|
||||||
|
for (int row = 0; row < input.rows; ++row)
|
||||||
|
{
|
||||||
|
int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
|
||||||
|
int value = catMap.at<int>(0, randomChosenIndex);
|
||||||
|
input.at<float>(row, var) = (float)value;
|
||||||
|
}
|
||||||
|
offset += size;
|
||||||
|
++countOffset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
string modelName;
|
||||||
|
string suffixes;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Loading checks for stored model files: one test per model kind, each
// listing the file-name endings to try (separated by ';').
TEST(ML_ANN, legacy_load)
{
    CV_LegacyTest test(CV_ANN, "_waveform.xml");
    test.safe_run();
}

TEST(ML_Boost, legacy_load)
{
    CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml");
    test.safe_run();
}

TEST(ML_DTree, legacy_load)
{
    CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml");
    test.safe_run();
}

TEST(ML_NBayes, legacy_load)
{
    CV_LegacyTest test(CV_NBAYES, "_waveform.xml");
    test.safe_run();
}

TEST(ML_SVM, legacy_load)
{
    CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml");
    test.safe_run();
}

TEST(ML_RTrees, legacy_load)
{
    CV_LegacyTest test(CV_RTREES, "_waveform.xml");
    test.safe_run();
}
|
||||||
|
|
||||||
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user