diff --git a/samples/cpp/letter_recog.cpp b/samples/cpp/letter_recog.cpp index 144dbe836..74d5971ca 100644 --- a/samples/cpp/letter_recog.cpp +++ b/samples/cpp/letter_recog.cpp @@ -131,7 +131,7 @@ int build_rtrees_classifier( char* data_filename, printf( "Could not read the classifier %s\n", filename_to_load ); return -1; } - printf( "The classifier %s is loaded.\n", data_filename ); + printf( "The classifier %s is loaded.\n", filename_to_load ); } else { @@ -262,7 +262,7 @@ int build_boost_classifier( char* data_filename, printf( "Could not read the classifier %s\n", filename_to_load ); return -1; } - printf( "The classifier %s is loaded.\n", data_filename ); + printf( "The classifier %s is loaded.\n", filename_to_load ); } else { @@ -403,7 +403,7 @@ int build_mlp_classifier( char* data_filename, printf( "Could not read the classifier %s\n", filename_to_load ); return -1; } - printf( "The classifier %s is loaded.\n", data_filename ); + printf( "The classifier %s is loaded.\n", filename_to_load ); } else { @@ -639,10 +639,11 @@ int build_nbayes_classifier( char* data_filename ) } static -int build_svm_classifier( char* data_filename ) +int build_svm_classifier( char* data_filename, const char* filename_to_save, const char* filename_to_load ) { CvMat* data = 0; CvMat* responses = 0; + CvMat* train_resp = 0; CvMat train_data; int nsamples_all = 0, ntrain_samples = 0; int var_count; @@ -666,13 +667,29 @@ int build_svm_classifier( char* data_filename ) ntrain_samples = (int)(nsamples_all*0.1); var_count = data->cols; - // train classifier - printf( "Training the classifier (may take a few minutes)...\n"); - cvGetRows( data, &train_data, 0, ntrain_samples ); - CvMat* train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1); - for (int i = 0; i < ntrain_samples; i++) - train_resp->data.fl[i] = responses->data.fl[i]; - svm.train(&train_data, train_resp, 0, 0, param); + // Create or load Random Trees classifier + if( filename_to_load ) + { + // load classifier from the specified file + svm.load( filename_to_load ); + ntrain_samples = 0; + if( svm.get_var_count() == 0 ) + { + printf( "Could not read the classifier %s\n", filename_to_load ); + return -1; + } + printf( "The classifier %s is loaded.\n", filename_to_load ); + } + else + { + // train classifier + printf( "Training the classifier (may take a few minutes)...\n"); + cvGetRows( data, &train_data, 0, ntrain_samples ); + train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1); + for (int i = 0; i < ntrain_samples; i++) + train_resp->data.fl[i] = responses->data.fl[i]; + svm.train(&train_data, train_resp, 0, 0, param); + } // classification std::vector _sample(var_count * (nsamples_all - ntrain_samples)); @@ -705,6 +722,9 @@ int build_svm_classifier( char* data_filename ) printf("true_resp = %f%%\n", (float)true_resp / (nsamples_all - ntrain_samples) * 100); + if( filename_to_save ) + svm.save( filename_to_save ); + cvReleaseMat( &train_resp ); cvReleaseMat( &result ); cvReleaseMat( &data ); @@ -775,7 +795,7 @@ int main( int argc, char *argv[] ) method == 4 ? build_nbayes_classifier( data_filename) : method == 5 ? - build_svm_classifier( data_filename ): + build_svm_classifier( data_filename, filename_to_save, filename_to_load ): -1) < 0) { help();