Merge pull request #5783 from vpisarev:ml_fixes

Vadim Pisarevsky 2015-12-11 06:25:15 +00:00
commit 286ba8cffd
7 changed files with 136 additions and 21 deletions


@@ -675,11 +675,19 @@ public:
/** @brief Retrieves all the support vectors
-The method returns all the support vector as floating-point matrix, where support vectors are
+The method returns all the support vectors as a floating-point matrix, where support vectors are
stored as matrix rows.
*/
CV_WRAP virtual Mat getSupportVectors() const = 0;
/** @brief Retrieves all the uncompressed support vectors of a linear %SVM
The method returns all the uncompressed support vectors of a linear %SVM that the compressed
support vector, used for prediction, was derived from. They are returned in a floating-point
matrix, where the support vectors are stored as matrix rows.
*/
CV_WRAP Mat getUncompressedSupportVectors() const;
/** @brief Retrieves the decision function
@param i the index of the decision function. If the problem solved is regression, 1-class or
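For illustration, a minimal usage sketch of the new accessor (it mirrors the test added later in this PR; the toy training data is the same):

    #include <opencv2/ml.hpp>
    using namespace cv;
    using namespace cv::ml;

    int main()
    {
        // Toy data: one positive and three negative 2-D points
        float samples[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };
        int labels[4] = { 1, -1, -1, -1 };
        Mat samplesMat(4, 2, CV_32F, samples), labelsMat(4, 1, CV_32S, labels);

        Ptr<SVM> svm = SVM::create();
        svm->setType(SVM::C_SVC);
        svm->setKernel(SVM::LINEAR);
        svm->train(samplesMat, ROW_SAMPLE, labelsMat);

        Mat compressed = svm->getSupportVectors();           // one row for a linear SVM
        Mat original = svm->getUncompressedSupportVectors(); // the vectors it was derived from
        return 0;
    }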


@@ -253,7 +253,7 @@ public:
if( !sampleIdx.empty() )
{
CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
-            checkRange(sampleIdx, true, 0, 0, nsamples-1)) ||
+            checkRange(sampleIdx, true, 0, 0, nsamples)) ||
sampleIdx.checkVector(1, CV_8U, true) == nsamples );
if( sampleIdx.type() == CV_8U )
sampleIdx = convertMaskToIdx(sampleIdx);
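The bounds fix above follows from cv::checkRange's semantics: the maxVal bound is exclusive (it verifies minVal <= value < maxVal), so accepting every valid sample index in [0, nsamples-1] requires passing nsamples as the upper bound. A small sketch:

    #include <opencv2/core.hpp>
    using namespace cv;

    int main()
    {
        int nsamples = 5;
        Mat idx = (Mat_<int>(1, 1) << nsamples - 1);            // the last valid sample index
        CV_Assert(!checkRange(idx, true, 0, 0, nsamples - 1));  // the old bound rejected it
        CV_Assert( checkRange(idx, true, 0, 0, nsamples));      // the fixed bound accepts it
        return 0;
    }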


@@ -187,7 +187,7 @@ public:
oobidx.clear();
for( i = 0; i < n; i++ )
{
-if( !oobmask[i] )
+if( oobmask[i] )
oobidx.push_back(i);
}
int n_oob = (int)oobidx.size();
@@ -217,6 +217,7 @@ public:
else
{
int ival = cvRound(val);
// Voting scheme to combine the OOB predictions of each tree
int* votes = &oobvotes[j*nclasses];
votes[ival]++;
int best_class = 0;
@@ -232,38 +233,39 @@ public:
oobError /= n_oob;
if( rparams.calcVarImportance && n_oob > 1 )
{
Mat sample_clone;
oobperm.resize(n_oob);
for( i = 0; i < n_oob; i++ )
oobperm[i] = oobidx[i];
for (i = n_oob - 1; i > 0; --i) //Randomly shuffle indices so we can permute features
{
int r_i = rng.uniform(0, n_oob);
std::swap(oobperm[i], oobperm[r_i]);
}
for( vi_ = 0; vi_ < nvars; vi_++ )
{
-vi = vidx ? vidx[vi_] : vi_;
+vi = vidx ? vidx[vi_] : vi_; // Ensure that only the user-specified predictors are used for training
double ncorrect_responses_permuted = 0;
-for( i = 0; i < n_oob; i++ )
-{
-    int i1 = rng.uniform(0, n_oob);
-    int i2 = rng.uniform(0, n_oob);
-    std::swap(i1, i2);
-}
for( i = 0; i < n_oob; i++ )
{
j = oobidx[i];
int vj = oobperm[i];
sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
-for( k = 0; k < nallvars; k++ )
-    sample.at<float>(k) = sample0.at<float>(k);
-sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
+sample0.copyTo(sample_clone); // Create a copy so we don't mess up the original data
+sample_clone.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
-double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
+double val = predictTrees(Range(treeidx, treeidx+1), sample_clone, predictFlags);
if( !_isClassifier )
{
val = (val - w->ord_responses[w->sidx[j]])/max_response;
ncorrect_responses_permuted += exp( -val*val );
}
else
{
ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
}
}
varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
}
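Two bugs are fixed in this hunk: the old pre-loop "shuffle" swapped two local ints (a no-op, so importance was effectively computed on unpermuted data), and the element-wise copy wrote through a sample header that still aliased the training matrix, corrupting the original data; the new code shuffles oobperm once per tree and modifies a deep copy, sample_clone. A self-contained sketch of the permuted-predictor importance idea, where predictOne is a hypothetical stand-in for predictTrees:

    #include <opencv2/core.hpp>
    #include <algorithm>
    #include <vector>

    // Hypothetical stand-in for the forest's per-sample prediction
    static int predictOne(const cv::Mat& sample) { return sample.at<float>(0) > 0 ? 1 : -1; }

    static float permutedImportance(const cv::Mat& samples, const std::vector<int>& labels,
                                    const std::vector<int>& oobidx, int vi, cv::RNG& rng)
    {
        std::vector<int> perm(oobidx);                       // copy of the OOB indices
        for (int i = (int)perm.size() - 1; i > 0; --i)
            std::swap(perm[i], perm[rng.uniform(0, i + 1)]); // Fisher-Yates shuffle
        float ncorrect = 0, ncorrect_permuted = 0;
        cv::Mat sample_clone;
        for (size_t i = 0; i < oobidx.size(); i++)
        {
            int j = oobidx[i];
            ncorrect += predictOne(samples.row(j)) == labels[j];
            samples.row(j).copyTo(sample_clone);             // never touch the original data
            sample_clone.at<float>(0, vi) = samples.at<float>(perm[i], vi); // permute one predictor
            ncorrect_permuted += predictOne(sample_clone) == labels[j];
        }
        return ncorrect - ncorrect_permuted; // accuracy drop = importance contribution of vi
    }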


@@ -1241,6 +1241,12 @@ public:
df_alpha.clear();
df_index.clear();
sv.release();
uncompressed_sv.release();
}
Mat getUncompressedSupportVectors_() const
{
return uncompressed_sv;
}
Mat getSupportVectors() const
@@ -1538,6 +1544,7 @@ public:
}
optimize_linear_svm();
return true;
}
@@ -1588,6 +1595,7 @@ public:
setRangeVector(df_index, df_count);
df_alpha.assign(df_count, 1.);
sv.copyTo(uncompressed_sv);
std::swap(sv, new_sv);
std::swap(decision_func, new_df);
}
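Context for the copy above: optimize_linear_svm() can collapse the support vectors of each linear decision function into a single vector because the response depends on them only through w = sum_i alpha_i * sv_i; the original vectors are now saved in uncompressed_sv before new_sv is swapped in. A sketch of that compression step (not the actual SVMImpl code; alpha is assumed to already carry the label signs):

    #include <opencv2/core.hpp>
    #include <vector>
    using namespace cv;

    static Mat compressLinearSVs(const Mat& sv, const std::vector<double>& alpha)
    {
        Mat w = Mat::zeros(1, sv.cols, CV_32F);
        for (int i = 0; i < sv.rows; i++)
            w += alpha[i] * sv.row(i);  // w = sum_i alpha_i * sv_i
        return w;                       // prediction reduces to sign(w.dot(sample) - rho)
    }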
@@ -1822,8 +1830,8 @@ public:
}
}
-params = best_params;
+class_labels = class_labels0;
+setParams(best_params);
return do_train( samples, responses );
}
@@ -2056,6 +2064,21 @@ public:
}
fs << "]";
if ( !uncompressed_sv.empty() )
{
// write the joint collection of uncompressed support vectors
int uncompressed_sv_total = uncompressed_sv.rows;
fs << "uncompressed_sv_total" << uncompressed_sv_total;
fs << "uncompressed_support_vectors" << "[";
for( i = 0; i < uncompressed_sv_total; i++ )
{
fs << "[:";
fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize());
fs << "]";
}
fs << "]";
}
// write decision functions
int df_count = (int)decision_func.size();
@@ -2096,7 +2119,7 @@ public:
svm_type_str == "NU_SVR" ? NU_SVR : -1;
if( svmType < 0 )
CV_Error( CV_StsParseError, "Missing of invalid SVM type" );
CV_Error( CV_StsParseError, "Missing or invalid SVM type" );
FileNode kernel_node = fn["kernel"];
if( kernel_node.empty() )
@@ -2168,14 +2191,31 @@ public:
FileNode sv_node = fn["support_vectors"];
CV_Assert((int)sv_node.size() == sv_total);
sv.create(sv_total, var_count, CV_32F);
FileNodeIterator sv_it = sv_node.begin();
for( i = 0; i < sv_total; i++, ++sv_it )
{
(*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
}
int uncompressed_sv_total = (int)fn["uncompressed_sv_total"];
if( uncompressed_sv_total > 0 )
{
// read uncompressed support vectors
FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"];
CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total);
uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F);
FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin();
for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it )
{
(*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize());
}
}
// read decision functions
int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
FileNode df_node = fn["decision_functions"];
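Together with the writer above, this makes the uncompressed vectors survive (de)serialization. A round-trip sketch (the file name is illustrative):

    #include <opencv2/ml.hpp>
    using namespace cv;
    using namespace cv::ml;

    void roundTrip(const Ptr<SVM>& svm)
    {
        svm->save("svm.yml");  // writes uncompressed_support_vectors for a linear SVM
        Ptr<SVM> loaded = Algorithm::load<SVM>("svm.yml");
        Mat sv0 = svm->getUncompressedSupportVectors();
        Mat sv1 = loaded->getUncompressedSupportVectors();
        CV_Assert(sv0.rows == sv1.rows && sv0.cols == sv1.cols);
    }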
@@ -2207,7 +2247,7 @@ public:
SvmParams params;
Mat class_labels;
int var_count;
Mat sv;
Mat sv, uncompressed_sv;
vector<DecisionFunc> decision_func;
vector<double> df_alpha;
vector<int> df_index;
@@ -2221,6 +2261,14 @@ Ptr<SVM> SVM::create()
return makePtr<SVMImpl>();
}
Mat SVM::getUncompressedSupportVectors() const
{
const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
if(!this_)
CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
return this_->getUncompressedSupportVectors_();
}
}
}
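Note the dispatch style here: getUncompressedSupportVectors() is a plain (non-virtual) method on the abstract SVM interface that reaches the implementation through dynamic_cast, presumably so the already-published 3.x vtable layout stays intact, since adding a new virtual would break binary compatibility. The same pattern in miniature, with hypothetical names:

    #include <cassert>

    struct Base { virtual ~Base() {} int newAccessor() const; }; // added without a new virtual slot
    struct Impl : public Base { int value = 42; };

    int Base::newAccessor() const
    {
        // Route through RTTI instead of the vtable, as SVM::getUncompressedSupportVectors() does
        const Impl* p = dynamic_cast<const Impl*>(this);
        assert(p && "the object is not an Impl");
        return p->value;
    }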


@@ -118,3 +118,51 @@ TEST(ML_SVM, trainAuto_regression_5369)
EXPECT_EQ(0., result0);
EXPECT_EQ(1., result1);
}
class CV_SVMGetSupportVectorsTest : public cvtest::BaseTest {
public:
CV_SVMGetSupportVectorsTest() {}
protected:
virtual void run( int startFrom );
};
void CV_SVMGetSupportVectorsTest::run(int /*startFrom*/ )
{
int code = cvtest::TS::OK;
// Set up training data
int labels[4] = {1, -1, -1, -1};
float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
Mat labelsMat(4, 1, CV_32SC1, labels);
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
// Test retrieval of the compressed and uncompressed SVs of a linear SVM
svm->setKernel(SVM::LINEAR);
svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat);
Mat sv = svm->getSupportVectors();
CV_Assert(sv.rows == 1); // by default the compressed SV is returned
sv = svm->getUncompressedSupportVectors();
CV_Assert(sv.rows == 3);
// Test retrieval of SVs and uncompressed SVs on a non-linear SVM
svm->setKernel(SVM::POLY);
svm->setDegree(2);
svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat);
sv = svm->getSupportVectors();
CV_Assert(sv.rows == 3);
sv = svm->getUncompressedSupportVectors();
CV_Assert(sv.rows == 0); // inapplicable for non-linear SVMs
ts->set_failed_test_info(code);
}
TEST(ML_SVM, getSupportVectors) { CV_SVMGetSupportVectorsTest test; test.safe_run(); }


@@ -63,7 +63,6 @@ int main(int argc, char** argv)
const double train_test_split_ratio = 0.5;
Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
if( data.empty() )
{
printf("ERROR: File %s can not be read\n", filename);
@@ -71,6 +70,7 @@ int main(int argc, char** argv)
}
data->setTrainTestSplitRatio(train_test_split_ratio);
std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples();
printf("======DTREE=====\n");
Ptr<DTrees> dtree = DTrees::create();
@@ -106,10 +106,19 @@ int main(int argc, char** argv)
rtrees->setUseSurrogates(false);
rtrees->setMaxCategories(16);
rtrees->setPriors(Mat());
-rtrees->setCalculateVarImportance(false);
+rtrees->setCalculateVarImportance(true);
rtrees->setActiveVarCount(0);
rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
train_and_print_errs(rtrees, data);
cv::Mat ref_labels = data->getClassLabels();
cv::Mat test_data = data->getTestSampleIdx();
cv::Mat predict_labels;
rtrees->predict(data->getSamples(), predict_labels);
cv::Mat variable_importance = rtrees->getVarImportance();
std::cout << "Estimated variable importance" << std::endl;
for (int i = 0; i < variable_importance.rows; i++) {
std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
}
return 0;
}
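The raw importance numbers printed above depend on the scale of the accumulated per-tree accuracy drops, so relative values are often easier to compare. A short follow-up sketch, reusing the rtrees model from the sample above:

    // Print each variable's share of the total importance as a percentage
    cv::Mat vi = rtrees->getVarImportance();
    double total = cv::sum(vi)[0];
    if (total > 0)
        for (int i = 0; i < vi.rows; i++)
            std::cout << "Variable " << i << ": "
                      << 100.0 * vi.at<float>(i, 0) / total << "%" << std::endl;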


@@ -65,7 +65,7 @@ int main(int, char**)
//! [show_vectors]
thickness = 2;
lineType = 8;
-Mat sv = svm->getSupportVectors();
+Mat sv = svm->getUncompressedSupportVectors();
for (int i = 0; i < sv.rows; ++i)
{