Merge pull request #687 from vpisarev:fast_lin_svm2
This commit is contained in:
		@@ -1551,25 +1551,28 @@ void CvSVM::optimize_linear_svm()
 | 
				
			|||||||
        return;
 | 
					        return;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int var_count = get_var_count();
 | 
					    int var_count = get_var_count();
 | 
				
			||||||
    int sample_size = (int)(var_count*sizeof(sv[0][0]));
 | 
					    cv::AutoBuffer<double> vbuf(var_count);
 | 
				
			||||||
 | 
					    double* v = vbuf;
 | 
				
			||||||
    float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));
 | 
					    float** new_sv = (float**)cvMemStorageAlloc(storage, df_count*sizeof(new_sv[0]));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for( i = 0; i < df_count; i++ )
 | 
					    for( i = 0; i < df_count; i++ )
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        new_sv[i] = (float*)cvMemStorageAlloc(storage, sample_size);
 | 
					        new_sv[i] = (float*)cvMemStorageAlloc(storage, var_count*sizeof(new_sv[i][0]));
 | 
				
			||||||
        float* dst = new_sv[i];
 | 
					        float* dst = new_sv[i];
 | 
				
			||||||
        memset(dst, 0, sample_size);
 | 
					        memset(v, 0, var_count*sizeof(v[0]));
 | 
				
			||||||
        int j, k, sv_count = df[i].sv_count;
 | 
					        int j, k, sv_count = df[i].sv_count;
 | 
				
			||||||
        for( j = 0; j < sv_count; j++ )
 | 
					        for( j = 0; j < sv_count; j++ )
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            const float* src = class_count > 1 ? sv[df[i].sv_index[j]] : sv[j];
 | 
					            const float* src = class_count > 1 && df[i].sv_index ? sv[df[i].sv_index[j]] : sv[j];
 | 
				
			||||||
            double a = df[i].alpha[j];
 | 
					            double a = df[i].alpha[j];
 | 
				
			||||||
            for( k = 0; k < var_count; k++ )
 | 
					            for( k = 0; k < var_count; k++ )
 | 
				
			||||||
                dst[k] = (float)(dst[k] + src[k]*a);
 | 
					                v[k] += src[k]*a;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					        for( k = 0; k < var_count; k++ )
 | 
				
			||||||
 | 
					            dst[k] = (float)v[k];
 | 
				
			||||||
        df[i].sv_count = 1;
 | 
					        df[i].sv_count = 1;
 | 
				
			||||||
        df[i].alpha[0] = 1.;
 | 
					        df[i].alpha[0] = 1.;
 | 
				
			||||||
        if( class_count > 1 )
 | 
					        if( class_count > 1 && df[i].sv_index )
 | 
				
			||||||
            df[i].sv_index[0] = i;
 | 
					            df[i].sv_index[0] = i;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2570,6 +2573,7 @@ void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
 | 
				
			|||||||
        CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
 | 
					        CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if( cvReadIntByName(fs, svm_node, "optimize_linear", 1) != 0 )
 | 
				
			||||||
        optimize_linear_svm();
 | 
					        optimize_linear_svm();
 | 
				
			||||||
    create_kernel();
 | 
					    create_kernel();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -769,7 +769,11 @@ void CV_MLBaseTest::load( const char* filename )
 | 
				
			|||||||
    else if( !modelName.compare(CV_KNEAREST) )
 | 
					    else if( !modelName.compare(CV_KNEAREST) )
 | 
				
			||||||
        knearest->load( filename );
 | 
					        knearest->load( filename );
 | 
				
			||||||
    else if( !modelName.compare(CV_SVM) )
 | 
					    else if( !modelName.compare(CV_SVM) )
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        delete svm;
 | 
				
			||||||
 | 
					        svm = new CvSVM;
 | 
				
			||||||
        svm->load( filename );
 | 
					        svm->load( filename );
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    else if( !modelName.compare(CV_ANN) )
 | 
					    else if( !modelName.compare(CV_ANN) )
 | 
				
			||||||
        ann->load( filename );
 | 
					        ann->load( filename );
 | 
				
			||||||
    else if( !modelName.compare(CV_DTREE) )
 | 
					    else if( !modelName.compare(CV_DTREE) )
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -82,32 +82,53 @@ int CV_SLMLTest::validate_test_results( int testCaseIdx )
 | 
				
			|||||||
    int code = cvtest::TS::OK;
 | 
					    int code = cvtest::TS::OK;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // 1. compare files
 | 
					    // 1. compare files
 | 
				
			||||||
    ifstream f1( fname1.c_str() ), f2( fname2.c_str() );
 | 
					    FILE *fs1 = fopen(fname1.c_str(), "rb"), *fs2 = fopen(fname2.c_str(), "rb");
 | 
				
			||||||
    string s1, s2;
 | 
					    size_t sz1 = 0, sz2 = 0;
 | 
				
			||||||
    int lineIdx = 0;
 | 
					    if( !fs1 || !fs2 )
 | 
				
			||||||
    CV_Assert( f1.is_open() && f2.is_open() );
 | 
					        code = cvtest::TS::FAIL_MISSING_TEST_DATA;
 | 
				
			||||||
    for( ; !f1.eof() && !f2.eof(); lineIdx++ )
 | 
					    if( code >= 0 )
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        getline( f1, s1 );
 | 
					        fseek(fs1, 0, SEEK_END); fseek(fs2, 0, SEEK_END);
 | 
				
			||||||
        getline( f2, s2 );
 | 
					        sz1 = ftell(fs1);
 | 
				
			||||||
        if( s1.compare(s2) )
 | 
					        sz2 = ftell(fs2);
 | 
				
			||||||
        {
 | 
					        fseek(fs1, 0, SEEK_SET); fseek(fs2, 0, SEEK_SET);
 | 
				
			||||||
            ts->printf( cvtest::TS::LOG, "first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
 | 
					    }
 | 
				
			||||||
               lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
 | 
					
 | 
				
			||||||
 | 
					    if( sz1 != sz2 )
 | 
				
			||||||
        code = cvtest::TS::FAIL_INVALID_OUTPUT;
 | 
					        code = cvtest::TS::FAIL_INVALID_OUTPUT;
 | 
				
			||||||
        }
 | 
					
 | 
				
			||||||
    }
 | 
					    if( code >= 0 )
 | 
				
			||||||
    if( !f1.eof() || !f2.eof() )
 | 
					 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        ts->printf( cvtest::TS::LOG, "in test case %d first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
 | 
					        const int BUFSZ = 1024;
 | 
				
			||||||
            testCaseIdx, lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
 | 
					        uchar buf1[BUFSZ], buf2[BUFSZ];
 | 
				
			||||||
 | 
					        for( size_t pos = 0; pos < sz1;  )
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            size_t r1 = fread(buf1, 1, BUFSZ, fs1);
 | 
				
			||||||
 | 
					            size_t r2 = fread(buf2, 1, BUFSZ, fs2);
 | 
				
			||||||
 | 
					            if( r1 != r2 || memcmp(buf1, buf2, r1) != 0 )
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                ts->printf( cvtest::TS::LOG,
 | 
				
			||||||
 | 
					                           "in test case %d first (%s) and second (%s) saved files differ in %d-th kb\n",
 | 
				
			||||||
 | 
					                           testCaseIdx, fname1.c_str(), fname2.c_str(),
 | 
				
			||||||
 | 
					                           (int)pos );
 | 
				
			||||||
                code = cvtest::TS::FAIL_INVALID_OUTPUT;
 | 
					                code = cvtest::TS::FAIL_INVALID_OUTPUT;
 | 
				
			||||||
 | 
					                break;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
    f1.close();
 | 
					            pos += r1;
 | 
				
			||||||
    f2.close();
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if(fs1)
 | 
				
			||||||
 | 
					        fclose(fs1);
 | 
				
			||||||
 | 
					    if(fs2)
 | 
				
			||||||
 | 
					        fclose(fs2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // delete temporary files
 | 
					    // delete temporary files
 | 
				
			||||||
 | 
					    if( code >= 0 )
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
        remove( fname1.c_str() );
 | 
					        remove( fname1.c_str() );
 | 
				
			||||||
        remove( fname2.c_str() );
 | 
					        remove( fname2.c_str() );
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // 2. compare responses
 | 
					    // 2. compare responses
 | 
				
			||||||
    CV_Assert( test_resps1.size() == test_resps2.size() );
 | 
					    CV_Assert( test_resps1.size() == test_resps2.size() );
 | 
				
			||||||
@@ -133,4 +154,32 @@ TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
 | 
				
			|||||||
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
 | 
					TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
 | 
				
			||||||
TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
 | 
					TEST(ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(DISABLED_ML_SVM, linear_save_load)
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					    CvSVM svm1, svm2, svm3;
 | 
				
			||||||
 | 
					    svm1.load("SVM45_X_38-1.xml");
 | 
				
			||||||
 | 
					    svm2.load("SVM45_X_38-2.xml");
 | 
				
			||||||
 | 
					    string tname = tempfile("a.xml");
 | 
				
			||||||
 | 
					    svm2.save(tname.c_str());
 | 
				
			||||||
 | 
					    svm3.load(tname.c_str());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_EQ(svm1.get_var_count(), svm2.get_var_count());
 | 
				
			||||||
 | 
					    ASSERT_EQ(svm1.get_var_count(), svm3.get_var_count());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int m = 10000, n = svm1.get_var_count();
 | 
				
			||||||
 | 
					    Mat samples(m, n, CV_32F), r1, r2, r3;
 | 
				
			||||||
 | 
					    randu(samples, 0., 1.);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    svm1.predict(samples, r1);
 | 
				
			||||||
 | 
					    svm2.predict(samples, r2);
 | 
				
			||||||
 | 
					    svm3.predict(samples, r3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    double eps = 1e-4;
 | 
				
			||||||
 | 
					    EXPECT_LE(norm(r1, r2, NORM_INF), eps);
 | 
				
			||||||
 | 
					    EXPECT_LE(norm(r1, r3, NORM_INF), eps);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    remove(tname.c_str());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* End of file. */
 | 
					/* End of file. */
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user