Merge pull request #669 from vpisarev:fast_lin_svm
This commit is contained in:
commit
321070ccf0
@ -534,6 +534,8 @@ protected:
|
|||||||
virtual void write_params( CvFileStorage* fs ) const;
|
virtual void write_params( CvFileStorage* fs ) const;
|
||||||
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
|
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
|
||||||
|
|
||||||
|
void optimize_linear_svm();
|
||||||
|
|
||||||
CvSVMParams params;
|
CvSVMParams params;
|
||||||
CvMat* class_labels;
|
CvMat* class_labels;
|
||||||
int var_all;
|
int var_all;
|
||||||
|
@ -1517,6 +1517,7 @@ bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
optimize_linear_svm();
|
||||||
ok = true;
|
ok = true;
|
||||||
|
|
||||||
__END__;
|
__END__;
|
||||||
@ -1524,6 +1525,59 @@ bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float
|
|||||||
return ok;
|
return ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void CvSVM::optimize_linear_svm()
|
||||||
|
{
|
||||||
|
// we optimize only linear SVM: compress all the support vectors into one.
|
||||||
|
if( params.kernel_type != LINEAR )
|
||||||
|
return;
|
||||||
|
|
||||||
|
int class_count = class_labels ? class_labels->cols :
|
||||||
|
params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
|
||||||
|
|
||||||
|
int i, df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
||||||
|
CvSVMDecisionFunc* df = decision_func;
|
||||||
|
|
||||||
|
for( i = 0; i < df_count; i++ )
|
||||||
|
{
|
||||||
|
int sv_count = df[i].sv_count;
|
||||||
|
if( sv_count != 1 )
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if every decision functions uses a single support vector;
|
||||||
|
// it's already compressed. skip it then.
|
||||||
|
if( i == df_count )
|
||||||
|
return;
|
||||||
|
|
||||||
|
int var_count = get_var_count();
|
||||||
|
int sample_size = (int)(var_count*sizeof(sv[0][0]));
|
||||||
|
float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));
|
||||||
|
|
||||||
|
for( i = 0; i < df_count; i++ )
|
||||||
|
{
|
||||||
|
new_sv[i] = (float*)cvMemStorageAlloc(storage, sample_size);
|
||||||
|
float* dst = new_sv[i];
|
||||||
|
memset(dst, 0, sample_size);
|
||||||
|
int j, k, sv_count = df[i].sv_count;
|
||||||
|
for( j = 0; j < sv_count; j++ )
|
||||||
|
{
|
||||||
|
const float* src = class_count > 1 ? sv[df[i].sv_index[j]] : sv[j];
|
||||||
|
double a = df[i].alpha[j];
|
||||||
|
for( k = 0; k < var_count; k++ )
|
||||||
|
dst[k] = (float)(dst[k] + src[k]*a);
|
||||||
|
}
|
||||||
|
df[i].sv_count = 1;
|
||||||
|
df[i].alpha[0] = 1.;
|
||||||
|
if( class_count > 1 )
|
||||||
|
df[i].sv_index[0] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
sv = new_sv;
|
||||||
|
sv_total = df_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
|
bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
|
||||||
const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
|
const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
|
||||||
{
|
{
|
||||||
@ -2516,6 +2570,7 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
|
|||||||
CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
|
CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
optimize_linear_svm();
|
||||||
create_kernel();
|
create_kernel();
|
||||||
|
|
||||||
__END__;
|
__END__;
|
||||||
|
@ -131,7 +131,7 @@ int build_rtrees_classifier( char* data_filename,
|
|||||||
printf( "Could not read the classifier %s\n", filename_to_load );
|
printf( "Could not read the classifier %s\n", filename_to_load );
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
printf( "The classifier %s is loaded.\n", data_filename );
|
printf( "The classifier %s is loaded.\n", filename_to_load );
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@ -262,7 +262,7 @@ int build_boost_classifier( char* data_filename,
|
|||||||
printf( "Could not read the classifier %s\n", filename_to_load );
|
printf( "Could not read the classifier %s\n", filename_to_load );
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
printf( "The classifier %s is loaded.\n", data_filename );
|
printf( "The classifier %s is loaded.\n", filename_to_load );
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@ -403,7 +403,7 @@ int build_mlp_classifier( char* data_filename,
|
|||||||
printf( "Could not read the classifier %s\n", filename_to_load );
|
printf( "Could not read the classifier %s\n", filename_to_load );
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
printf( "The classifier %s is loaded.\n", data_filename );
|
printf( "The classifier %s is loaded.\n", filename_to_load );
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@ -639,10 +639,11 @@ int build_nbayes_classifier( char* data_filename )
|
|||||||
}
|
}
|
||||||
|
|
||||||
static
|
static
|
||||||
int build_svm_classifier( char* data_filename )
|
int build_svm_classifier( char* data_filename, const char* filename_to_save, const char* filename_to_load )
|
||||||
{
|
{
|
||||||
CvMat* data = 0;
|
CvMat* data = 0;
|
||||||
CvMat* responses = 0;
|
CvMat* responses = 0;
|
||||||
|
CvMat* train_resp = 0;
|
||||||
CvMat train_data;
|
CvMat train_data;
|
||||||
int nsamples_all = 0, ntrain_samples = 0;
|
int nsamples_all = 0, ntrain_samples = 0;
|
||||||
int var_count;
|
int var_count;
|
||||||
@ -666,13 +667,29 @@ int build_svm_classifier( char* data_filename )
|
|||||||
ntrain_samples = (int)(nsamples_all*0.1);
|
ntrain_samples = (int)(nsamples_all*0.1);
|
||||||
var_count = data->cols;
|
var_count = data->cols;
|
||||||
|
|
||||||
// train classifier
|
// Create or load Random Trees classifier
|
||||||
printf( "Training the classifier (may take a few minutes)...\n");
|
if( filename_to_load )
|
||||||
cvGetRows( data, &train_data, 0, ntrain_samples );
|
{
|
||||||
CvMat* train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1);
|
// load classifier from the specified file
|
||||||
for (int i = 0; i < ntrain_samples; i++)
|
svm.load( filename_to_load );
|
||||||
train_resp->data.fl[i] = responses->data.fl[i];
|
ntrain_samples = 0;
|
||||||
svm.train(&train_data, train_resp, 0, 0, param);
|
if( svm.get_var_count() == 0 )
|
||||||
|
{
|
||||||
|
printf( "Could not read the classifier %s\n", filename_to_load );
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
printf( "The classifier %s is loaded.\n", filename_to_load );
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// train classifier
|
||||||
|
printf( "Training the classifier (may take a few minutes)...\n");
|
||||||
|
cvGetRows( data, &train_data, 0, ntrain_samples );
|
||||||
|
train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1);
|
||||||
|
for (int i = 0; i < ntrain_samples; i++)
|
||||||
|
train_resp->data.fl[i] = responses->data.fl[i];
|
||||||
|
svm.train(&train_data, train_resp, 0, 0, param);
|
||||||
|
}
|
||||||
|
|
||||||
// classification
|
// classification
|
||||||
std::vector<float> _sample(var_count * (nsamples_all - ntrain_samples));
|
std::vector<float> _sample(var_count * (nsamples_all - ntrain_samples));
|
||||||
@ -691,7 +708,10 @@ int build_svm_classifier( char* data_filename )
|
|||||||
CvMat *result = cvCreateMat(1, nsamples_all - ntrain_samples, CV_32FC1);
|
CvMat *result = cvCreateMat(1, nsamples_all - ntrain_samples, CV_32FC1);
|
||||||
|
|
||||||
printf("Classification (may take a few minutes)...\n");
|
printf("Classification (may take a few minutes)...\n");
|
||||||
|
double t = (double)cvGetTickCount();
|
||||||
svm.predict(&sample, result);
|
svm.predict(&sample, result);
|
||||||
|
t = (double)cvGetTickCount() - t;
|
||||||
|
printf("Prediction type: %gms\n", t/(cvGetTickFrequency()*1000.));
|
||||||
|
|
||||||
int true_resp = 0;
|
int true_resp = 0;
|
||||||
for (int i = 0; i < nsamples_all - ntrain_samples; i++)
|
for (int i = 0; i < nsamples_all - ntrain_samples; i++)
|
||||||
@ -702,6 +722,9 @@ int build_svm_classifier( char* data_filename )
|
|||||||
|
|
||||||
printf("true_resp = %f%%\n", (float)true_resp / (nsamples_all - ntrain_samples) * 100);
|
printf("true_resp = %f%%\n", (float)true_resp / (nsamples_all - ntrain_samples) * 100);
|
||||||
|
|
||||||
|
if( filename_to_save )
|
||||||
|
svm.save( filename_to_save );
|
||||||
|
|
||||||
cvReleaseMat( &train_resp );
|
cvReleaseMat( &train_resp );
|
||||||
cvReleaseMat( &result );
|
cvReleaseMat( &result );
|
||||||
cvReleaseMat( &data );
|
cvReleaseMat( &data );
|
||||||
@ -772,7 +795,7 @@ int main( int argc, char *argv[] )
|
|||||||
method == 4 ?
|
method == 4 ?
|
||||||
build_nbayes_classifier( data_filename) :
|
build_nbayes_classifier( data_filename) :
|
||||||
method == 5 ?
|
method == 5 ?
|
||||||
build_svm_classifier( data_filename ):
|
build_svm_classifier( data_filename, filename_to_save, filename_to_load ):
|
||||||
-1) < 0)
|
-1) < 0)
|
||||||
{
|
{
|
||||||
help();
|
help();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user