integrated parallel SVM prediction; fixed warnings after meanshift integration
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
void help()
|
||||
{
|
||||
printf("\nThe sample demonstrates how to train Random Trees classifier\n"
|
||||
"(or Boosting classifier, or MLP, or Knearest, or Nbayes - see main()) using the provided dataset.\n"
|
||||
"(or Boosting classifier, or MLP, or Knearest, or Nbayes, or Support Vector Machines - see main()) using the provided dataset.\n"
|
||||
"\n"
|
||||
"We use the sample database letter-recognition.data\n"
|
||||
"from UCI Repository, here is the link:\n"
|
||||
@@ -28,7 +28,7 @@ void help()
|
||||
"The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
|
||||
" [-save <output XML file for the classifier>] \\\n"
|
||||
" [-load <XML file with the pre-trained classifier>] \\\n"
|
||||
" [-boost|-mlp|-knearest|-nbayes] # to use boost/mlp/knearest classifier instead of default Random Trees\n" );
|
||||
" [-boost|-mlp|-knearest|-nbayes|-svm] # to use boost/mlp/knearest/SVM classifier instead of default Random Trees\n" );
|
||||
}
|
||||
|
||||
// This function reads data and responses from the file <filename>
|
||||
@@ -630,6 +630,78 @@ int build_nbayes_classifier( char* data_filename )
|
||||
return 0;
|
||||
}
|
||||
|
||||
static
|
||||
int build_svm_classifier( char* data_filename )
|
||||
{
|
||||
CvMat* data = 0;
|
||||
CvMat* responses = 0;
|
||||
CvMat train_data;
|
||||
int nsamples_all = 0, ntrain_samples = 0;
|
||||
int var_count;
|
||||
CvSVM svm;
|
||||
|
||||
int ok = read_num_class_data( data_filename, 16, &data, &responses );
|
||||
if( !ok )
|
||||
{
|
||||
printf( "Could not read the database %s\n", data_filename );
|
||||
return -1;
|
||||
}
|
||||
////////// SVM parameters ///////////////////////////////
|
||||
CvSVMParams param;
|
||||
param.kernel_type=CvSVM::LINEAR;
|
||||
param.svm_type=CvSVM::C_SVC;
|
||||
param.C=1;
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
printf( "The database %s is loaded.\n", data_filename );
|
||||
nsamples_all = data->rows;
|
||||
ntrain_samples = (int)(nsamples_all*0.1);
|
||||
var_count = data->cols;
|
||||
|
||||
// train classifier
|
||||
printf( "Training the classifier (may take a few minutes)...\n");
|
||||
cvGetRows( data, &train_data, 0, ntrain_samples );
|
||||
CvMat* train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1);
|
||||
for (int i = 0; i < ntrain_samples; i++)
|
||||
train_resp->data.fl[i] = responses->data.fl[i];
|
||||
svm.train(&train_data, train_resp, 0, 0, param);
|
||||
|
||||
// classification
|
||||
float _sample[var_count * (nsamples_all - ntrain_samples)];
|
||||
CvMat sample = cvMat( nsamples_all - ntrain_samples, 16, CV_32FC1, _sample );
|
||||
float true_results[nsamples_all - ntrain_samples];
|
||||
for (int j = ntrain_samples; j < nsamples_all; j++)
|
||||
{
|
||||
float *s = data->data.fl + j * var_count;
|
||||
|
||||
for (int i = 0; i < var_count; i++)
|
||||
{
|
||||
sample.data.fl[(j - ntrain_samples) * var_count + i] = s[i];
|
||||
}
|
||||
true_results[j - ntrain_samples] = responses->data.fl[j];
|
||||
}
|
||||
CvMat *result = cvCreateMat(1, nsamples_all - ntrain_samples, CV_32FC1);
|
||||
|
||||
printf("Classification (may take a few minutes)...\n");
|
||||
(int)svm.predict(&sample, result);
|
||||
|
||||
int true_resp = 0;
|
||||
for (int i = 0; i < nsamples_all - ntrain_samples; i++)
|
||||
{
|
||||
if (result->data.fl[i] == true_results[i])
|
||||
true_resp++;
|
||||
}
|
||||
|
||||
printf("true_resp = %f%%\n", (float)true_resp / (nsamples_all - ntrain_samples) * 100);
|
||||
|
||||
cvReleaseMat( &train_resp );
|
||||
cvReleaseMat( &result );
|
||||
cvReleaseMat( &data );
|
||||
cvReleaseMat( &responses );
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main( int argc, char *argv[] )
|
||||
{
|
||||
char* filename_to_save = 0;
|
||||
@@ -672,6 +744,10 @@ int main( int argc, char *argv[] )
|
||||
{
|
||||
method = 4;
|
||||
}
|
||||
else if ( strcmp(argv[i], "-svm") == 0)
|
||||
{
|
||||
method = 5;
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
@@ -687,6 +763,8 @@ int main( int argc, char *argv[] )
|
||||
build_knearest_classifier( data_filename, 10 ) :
|
||||
method == 4 ?
|
||||
build_nbayes_classifier( data_filename) :
|
||||
method == 5 ?
|
||||
build_svm_classifier( data_filename ):
|
||||
-1) < 0)
|
||||
{
|
||||
help();
|
||||
|
Reference in New Issue
Block a user