From 989be02562c55e38ec6679ce2cbac6b63ef2f8c8 Mon Sep 17 00:00:00 2001 From: Maria Dimashova Date: Fri, 23 Aug 2013 18:13:10 +0400 Subject: [PATCH] fixed cpp wrappers of ML tree-based models --- modules/ml/include/opencv2/ml.hpp | 6 ++++++ modules/ml/src/boost.cpp | 11 ++++++++--- modules/ml/src/ertrees.cpp | 11 ++++++++--- modules/ml/src/rtrees.cpp | 11 ++++++++--- modules/ml/src/tree.cpp | 11 ++++++++--- 5 files changed, 38 insertions(+), 12 deletions(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 7a334f264..5e7871da3 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -942,6 +942,8 @@ protected: CvDTreeNode* root; CvMat* var_importance; CvDTreeTrainData* data; + CvMat train_data_hdr, responses_hdr; + cv::Mat train_data_mat, responses_mat; public: int pruned_tree_idx; @@ -1053,6 +1055,8 @@ protected: // array of the trees of the forest CvForestTree** trees; CvDTreeTrainData* data; + CvMat train_data_hdr, responses_hdr; + cv::Mat train_data_mat, responses_mat; int ntrees; int nclasses; double oob_error; @@ -1268,6 +1272,8 @@ protected: virtual void initialize_weights(double (&p)[2]); CvDTreeTrainData* data; + CvMat train_data_hdr, responses_hdr; + cv::Mat train_data_mat, responses_mat; CvBoostParams params; CvSeq* weak; diff --git a/modules/ml/src/boost.cpp b/modules/ml/src/boost.cpp index 53c194f3c..a22e13a53 100644 --- a/modules/ml/src/boost.cpp +++ b/modules/ml/src/boost.cpp @@ -2122,9 +2122,14 @@ CvBoost::train( const Mat& _train_data, int _tflag, const Mat& _missing_mask, CvBoostParams _params, bool _update ) { - CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, - sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask; - return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, + train_data_hdr = _train_data; + train_data_mat = _train_data; + responses_hdr = _responses; + responses_mat = _responses; + + CvMat vidx = _var_idx, sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask; + + return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params, _update); } diff --git a/modules/ml/src/ertrees.cpp b/modules/ml/src/ertrees.cpp index e379ed4b1..88f7374fa 100644 --- a/modules/ml/src/ertrees.cpp +++ b/modules/ml/src/ertrees.cpp @@ -1844,9 +1844,14 @@ bool CvERTrees::train( const Mat& _train_data, int _tflag, const Mat& _sample_idx, const Mat& _var_type, const Mat& _missing_mask, CvRTParams params ) { - CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, - sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask; - return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, + train_data_hdr = _train_data; + train_data_mat = _train_data; + responses_hdr = _responses; + responses_mat = _responses; + + CvMat vidx = _var_idx, sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask; + + return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, params); } diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 7947b062f..fcb1baf6b 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -839,9 +839,14 @@ bool CvRTrees::train( const Mat& _train_data, int _tflag, const Mat& _sample_idx, const Mat& _var_type, const Mat& _missing_mask, CvRTParams _params ) { - CvMat tdata = _train_data, responses = _responses, vidx = _var_idx, - sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask; - return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, + train_data_hdr = _train_data; + train_data_mat = _train_data; + responses_hdr = _responses; + responses_mat = _responses; + + CvMat vidx = _var_idx, sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask; + + return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params); } diff --git a/modules/ml/src/tree.cpp b/modules/ml/src/tree.cpp index e8072958f..d195385d7 100644 --- a/modules/ml/src/tree.cpp +++ b/modules/ml/src/tree.cpp @@ -1594,9 +1594,14 @@ bool CvDTree::train( const Mat& _train_data, int _tflag, const Mat& _sample_idx, const Mat& _var_type, const Mat& _missing_mask, CvDTreeParams _params ) { - CvMat tdata = _train_data, responses = _responses, vidx=_var_idx, - sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask; - return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, + train_data_hdr = _train_data; + train_data_mat = _train_data; + responses_hdr = _responses; + responses_mat = _responses; + + CvMat vidx=_var_idx, sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask; + + return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params); }