#include "test_precomp.hpp"

#include <string>
#include <fstream>
#include <iostream>

using namespace std;


class CV_GBTreesTest : public cvtest::BaseTest
{
public:
    CV_GBTreesTest();
    ~CV_GBTreesTest();

protected:
    void run(int);

    int TestTrainPredict(int test_num);
    int TestSaveLoad();

    int checkPredictError(int test_num);
    int checkLoadSave();

    string model_file_name1;
    string model_file_name2;

    string* datasets;
    string data_path;

    CvMLData* data;
    CvGBTrees* gtb;

    vector<float> test_resps1;
    vector<float> test_resps2;

    int64 initSeed;
};


int _get_len(const CvMat* mat)
{
    return (mat->cols > mat->rows) ? mat->cols : mat->rows;
}


CV_GBTreesTest::CV_GBTreesTest()
{
    int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
                      CV_BIG_INT(0x0000a17166072c7c),
                      CV_BIG_INT(0x0201b32115cd1f9a),
                      CV_BIG_INT(0x0513cb37abcd1234),
                      CV_BIG_INT(0x0001a2b3c4d5f678)
                    };

    int seedCount = sizeof(seeds)/sizeof(seeds[0]);
    cv::RNG& rng = cv::theRNG();
    initSeed = rng.state;
    rng.state = seeds[rng(seedCount)];

    datasets = 0;
    data = 0;
    gtb = 0;
}

CV_GBTreesTest::~CV_GBTreesTest()
{
    if (data)
        delete data;
    delete[] datasets;
    cv::theRNG().state = initSeed;
}


int CV_GBTreesTest::TestTrainPredict(int test_num)
{
    int code = cvtest::TS::OK;

    int weak_count = 200;
    float shrinkage = 0.1f;
    float subsample_portion = 0.5f;
    int max_depth = 5;
    bool use_surrogates = false;
    int loss_function_type = 0;
    switch (test_num)
    {
        case (1) : loss_function_type = CvGBTrees::SQUARED_LOSS; break;
        case (2) : loss_function_type = CvGBTrees::ABSOLUTE_LOSS; break;
        case (3) : loss_function_type = CvGBTrees::HUBER_LOSS; break;
        case (0) : loss_function_type = CvGBTrees::DEVIANCE_LOSS; break;
        default  :
            {
            ts->printf( cvtest::TS::LOG, "Bad test_num value in CV_GBTreesTest::TestTrainPredict(..) function." );
            return cvtest::TS::FAIL_BAD_ARG_CHECK;
            }
    }

    int dataset_num = test_num == 0 ? 0 : 1;
    if (!data)
    {
        data = new CvMLData();
        data->set_delimiter(',');

        if (data->read_csv(datasets[dataset_num].c_str()))
        {
            ts->printf( cvtest::TS::LOG, "File reading error." );
            return cvtest::TS::FAIL_INVALID_TEST_DATA;
        }

        if (test_num == 0)
        {
            data->set_response_idx(57);
            data->set_var_types("ord[0-56],cat[57]");
        }
        else
        {
            data->set_response_idx(13);
            data->set_var_types("ord[0-2,4-13],cat[3]");
            subsample_portion = 0.7f;
        }

        int train_sample_count = cvFloor(_get_len(data->get_responses())*0.5f);
        CvTrainTestSplit spl( train_sample_count );
        data->set_train_test_split( &spl );
    }

    data->mix_train_and_test_idx();


    if (gtb) delete gtb;
    gtb = new CvGBTrees();
    bool tmp_code = true;
    tmp_code = gtb->train(data, CvGBTreesParams(loss_function_type, weak_count,
                          shrinkage, subsample_portion,
                          max_depth, use_surrogates));

    if (!tmp_code)
    {
        ts->printf( cvtest::TS::LOG, "Model training was failed.");
        return cvtest::TS::FAIL_INVALID_OUTPUT;
    }

    code = checkPredictError(test_num);

    return code;

}


int CV_GBTreesTest::checkPredictError(int test_num)
{
    if (!gtb)
        return cvtest::TS::FAIL_GENERIC;

    //float mean[] = {5.430247f, 13.5654f, 12.6569f, 13.1661f};
    //float sigma[] = {0.4162694f, 3.21161f, 3.43297f, 3.00624f};
    float mean[] = {5.80226f, 12.68689f, 13.49095f, 13.19628f};
    float sigma[] = {0.4764534f, 3.166919f, 3.022405f, 2.868722f};

    float current_error = gtb->calc_error(data, CV_TEST_ERROR);

    if ( abs( current_error - mean[test_num]) > 6*sigma[test_num] )
    {
        ts->printf( cvtest::TS::LOG, "Test error is out of range:\n"
                    "abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/",
                    current_error, mean[test_num], 6*sigma[test_num] );
        return cvtest::TS::FAIL_BAD_ACCURACY;
    }

    return cvtest::TS::OK;

}


int CV_GBTreesTest::TestSaveLoad()
{
    if (!gtb)
        return cvtest::TS::FAIL_GENERIC;

    model_file_name1 = cv::tempfile();
    model_file_name2 = cv::tempfile();

    gtb->save(model_file_name1.c_str());
    gtb->calc_error(data, CV_TEST_ERROR, &test_resps1);
    gtb->load(model_file_name1.c_str());
    gtb->calc_error(data, CV_TEST_ERROR, &test_resps2);
    gtb->save(model_file_name2.c_str());

    return checkLoadSave();

}



int CV_GBTreesTest::checkLoadSave()
{
    int code = cvtest::TS::OK;

    // 1. compare files
    ifstream f1( model_file_name1.c_str() ), f2( model_file_name2.c_str() );
    string s1, s2;
    int lineIdx = 0;
    CV_Assert( f1.is_open() && f2.is_open() );
    for( ; !f1.eof() && !f2.eof(); lineIdx++ )
    {
        getline( f1, s1 );
        getline( f2, s2 );
        if( s1.compare(s2) )
        {
            ts->printf( cvtest::TS::LOG, "first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
               lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
            code = cvtest::TS::FAIL_INVALID_OUTPUT;
        }
    }
    if( !f1.eof() || !f2.eof() )
    {
        ts->printf( cvtest::TS::LOG, "First and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
            lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
        code = cvtest::TS::FAIL_INVALID_OUTPUT;
    }
    f1.close();
    f2.close();
    // delete temporary files
    remove( model_file_name1.c_str() );
    remove( model_file_name2.c_str() );

    // 2. compare responses
    CV_Assert( test_resps1.size() == test_resps2.size() );
    vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
    for( ; it1 != test_resps1.end(); ++it1, ++it2 )
    {
        if( fabs(*it1 - *it2) > FLT_EPSILON )
        {
            ts->printf( cvtest::TS::LOG, "Responses predicted before saving and after loading are different" );
            code = cvtest::TS::FAIL_INVALID_OUTPUT;
        }
    }
    return code;
}



void CV_GBTreesTest::run(int)
{

    string dataPath = string(ts->get_data_path());
    datasets = new string[2];
    datasets[0] = dataPath + string("spambase.data"); /*string("dataset_classification.csv");*/
    datasets[1] = dataPath + string("housing_.data");  /*string("dataset_regression.csv");*/

    int code = cvtest::TS::OK;

    for (int i = 0; i < 4; i++)
    {

        int temp_code = TestTrainPredict(i);
        if (temp_code != cvtest::TS::OK)
        {
            code = temp_code;
            break;
        }

        else if (i==0)
        {
            temp_code = TestSaveLoad();
            if (temp_code != cvtest::TS::OK)
                code = temp_code;
            delete data;
            data = 0;
        }

        delete gtb;
        gtb = 0;
    }
    delete data;
    data = 0;

    ts->set_failed_test_info( code );
}

/////////////////////////////////////////////////////////////////////////////
//////////////////// test registration  /////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////

TEST(ML_GBTrees, regression) { CV_GBTreesTest test; test.safe_run(); }