Return uncompressed support vectors for getSupportVectors on linear SVM (Bug #4096)
This commit is contained in:

committed by
Vadim Pisarevsky

parent
544990e377
commit
0d706f6796
@@ -1241,6 +1241,12 @@ public:
|
||||
df_alpha.clear();
|
||||
df_index.clear();
|
||||
sv.release();
|
||||
uncompressed_sv.release();
|
||||
}
|
||||
|
||||
Mat getUncompressedSupportVectors_() const
|
||||
{
|
||||
return uncompressed_sv;
|
||||
}
|
||||
|
||||
Mat getSupportVectors() const
|
||||
@@ -1538,6 +1544,7 @@ public:
|
||||
}
|
||||
|
||||
optimize_linear_svm();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1588,6 +1595,7 @@ public:
|
||||
|
||||
setRangeVector(df_index, df_count);
|
||||
df_alpha.assign(df_count, 1.);
|
||||
sv.copyTo(uncompressed_sv);
|
||||
std::swap(sv, new_sv);
|
||||
std::swap(decision_func, new_df);
|
||||
}
|
||||
@@ -2056,6 +2064,21 @@ public:
|
||||
}
|
||||
fs << "]";
|
||||
|
||||
if ( !uncompressed_sv.empty() )
|
||||
{
|
||||
// write the joint collection of uncompressed support vectors
|
||||
int uncompressed_sv_total = uncompressed_sv.rows;
|
||||
fs << "uncompressed_sv_total" << uncompressed_sv_total;
|
||||
fs << "uncompressed_support_vectors" << "[";
|
||||
for( i = 0; i < uncompressed_sv_total; i++ )
|
||||
{
|
||||
fs << "[:";
|
||||
fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize());
|
||||
fs << "]";
|
||||
}
|
||||
fs << "]";
|
||||
}
|
||||
|
||||
// write decision functions
|
||||
int df_count = (int)decision_func.size();
|
||||
|
||||
@@ -2096,7 +2119,7 @@ public:
|
||||
svm_type_str == "NU_SVR" ? NU_SVR : -1;
|
||||
|
||||
if( svmType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing of invalid SVM type" );
|
||||
CV_Error( CV_StsParseError, "Missing or invalid SVM type" );
|
||||
|
||||
FileNode kernel_node = fn["kernel"];
|
||||
if( kernel_node.empty() )
|
||||
@@ -2168,14 +2191,31 @@ public:
|
||||
FileNode sv_node = fn["support_vectors"];
|
||||
|
||||
CV_Assert((int)sv_node.size() == sv_total);
|
||||
sv.create(sv_total, var_count, CV_32F);
|
||||
|
||||
sv.create(sv_total, var_count, CV_32F);
|
||||
FileNodeIterator sv_it = sv_node.begin();
|
||||
for( i = 0; i < sv_total; i++, ++sv_it )
|
||||
{
|
||||
(*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
|
||||
}
|
||||
|
||||
int uncompressed_sv_total = (int)fn["uncompressed_sv_total"];
|
||||
|
||||
if( uncompressed_sv_total > 0 )
|
||||
{
|
||||
// read uncompressed support vectors
|
||||
FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"];
|
||||
|
||||
CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total);
|
||||
uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F);
|
||||
|
||||
FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin();
|
||||
for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it )
|
||||
{
|
||||
(*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize());
|
||||
}
|
||||
}
|
||||
|
||||
// read decision functions
|
||||
int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
||||
FileNode df_node = fn["decision_functions"];
|
||||
@@ -2207,7 +2247,7 @@ public:
|
||||
SvmParams params;
|
||||
Mat class_labels;
|
||||
int var_count;
|
||||
Mat sv;
|
||||
Mat sv, uncompressed_sv;
|
||||
vector<DecisionFunc> decision_func;
|
||||
vector<double> df_alpha;
|
||||
vector<int> df_index;
|
||||
@@ -2221,6 +2261,14 @@ Ptr<SVM> SVM::create()
|
||||
return makePtr<SVMImpl>();
|
||||
}
|
||||
|
||||
Mat SVM::getUncompressedSupportVectors() const
|
||||
{
|
||||
const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
|
||||
if(!this_)
|
||||
CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
|
||||
return this_->getUncompressedSupportVectors_();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user