Updated ml module interfaces and documentation

This commit is contained in:
Maksim Shabunin
2015-02-11 13:24:14 +03:00
parent da383e65e2
commit 79e8f0680c
32 changed files with 1403 additions and 1528 deletions

View File

@@ -120,6 +120,91 @@ namespace ml
return termCrit;
}
/**
 * Parameter set for a decision tree.
 *
 * The boolean flags and the priors matrix are plain public fields. The
 * numeric limits are protected and are written through the validating
 * setters below, each of which reports CV_StsOutOfRange via CV_Error for
 * an unacceptable value and/or clamps the value into a fixed range.
 * Both constructors are defined outside this header section.
 */
struct TreeParams
{
    TreeParams();
    TreeParams( int maxDepth, int minSampleCount,
                double regressionAccuracy, bool useSurrogates,
                int maxCategories, int CVFolds,
                bool use1SERule, bool truncatePrunedTree,
                const Mat& priors );

    /** Values below 2 are reported as an error; the stored value never exceeds 15. */
    void setMaxCategories(int val)
    {
        const int upperLimit = 15;
        if( val < 2 )
        {
            CV_Error( CV_StsOutOfRange, "max_categories should be >= 2" );
        }
        maxCategories = val < upperLimit ? val : upperLimit;
    }

    /** Negative values are reported as an error; the stored value never exceeds 25. */
    void setMaxDepth(int val)
    {
        const int upperLimit = 25;
        if( val < 0 )
        {
            CV_Error( CV_StsOutOfRange, "max_depth should be >= 0" );
        }
        maxDepth = val < upperLimit ? val : upperLimit;
    }

    /** No error path: anything below 1 is silently raised to 1. */
    void setMinSampleCount(int val)
    {
        minSampleCount = val > 1 ? val : 1;
    }

    /**
     * Negative values are reported as an error. A value of 1 is stored as 0,
     * i.e. it is treated the same as "no pruning" per the message below.
     */
    void setCVFolds(int val)
    {
        if( val < 0 )
        {
            CV_Error( CV_StsOutOfRange,
                      "params.CVFolds should be =0 (the tree is not pruned) "
                      "or n>0 (tree is pruned using n-fold cross-validation)" );
        }
        CVFolds = (val == 1) ? 0 : val;
    }

    /** Negative values are reported as an error; otherwise stored unchanged. */
    void setRegressionAccuracy(float val)
    {
        if( val < 0 )
        {
            CV_Error( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
        }
        regressionAccuracy = val;
    }

    // Read-only access to the protected, validated fields.
    int getMaxCategories() const
    {
        return maxCategories;
    }
    int getMaxDepth() const
    {
        return maxDepth;
    }
    int getMinSampleCount() const
    {
        return minSampleCount;
    }
    int getCVFolds() const
    {
        return CVFolds;
    }
    float getRegressionAccuracy() const
    {
        return regressionAccuracy;
    }

    // Accessors for the unvalidated public fields are produced by these
    // macros (defined elsewhere; presumably a get/set pair per property -
    // confirm against the macro definitions).
    CV_IMPL_PROPERTY(bool, UseSurrogates, useSurrogates)
    CV_IMPL_PROPERTY(bool, Use1SERule, use1SERule)
    CV_IMPL_PROPERTY(bool, TruncatePrunedTree, truncatePrunedTree)
    CV_IMPL_PROPERTY_S(cv::Mat, Priors, priors)

public:
    bool useSurrogates;
    bool use1SERule;
    bool truncatePrunedTree;
    Mat priors;

protected:
    // Written only through the setters above (and, presumably, the
    // constructors - their bodies are not visible here).
    int maxCategories;
    int maxDepth;
    int minSampleCount;
    int CVFolds;
    float regressionAccuracy;
};
// Plain parameter struct: three public fields plus a default constructor and
// a constructor taking one argument per field. Constructor bodies are defined
// outside this view, so defaults and any validation are not visible here.
struct RTreeParams
{
RTreeParams();
RTreeParams(bool calcVarImportance, int nactiveVars, TermCriteria termCrit );
// NOTE(review): presumably toggles computation of variable importance - confirm in the trainer.
bool calcVarImportance;
// NOTE(review): presumably the number of variables considered per split - confirm in the trainer.
int nactiveVars;
// Termination criteria object; how it is interpreted is not shown in this section.
TermCriteria termCrit;
};
// Plain parameter struct: three public fields plus a default constructor and
// a constructor taking one argument per field. Constructor bodies are defined
// outside this view, so defaults and any validation are not visible here.
struct BoostTreeParams
{
BoostTreeParams();
BoostTreeParams(int boostType, int weakCount, double weightTrimRate);
// NOTE(review): presumably an enumerator selecting the boosting variant - confirm against the Boost interface.
int boostType;
// NOTE(review): presumably the number of weak learners - confirm against the Boost interface.
int weakCount;
// NOTE(review): valid range and meaning are not visible here - confirm against the Boost interface.
double weightTrimRate;
};
class DTreesImpl : public DTrees
{
public:
@@ -191,6 +276,16 @@ namespace ml
int maxSubsetSize;
};
// One macro invocation per tree parameter, each naming the value type, the
// property name and the `params` member. The macros are defined elsewhere;
// presumably each expands to a get/set pair that forwards to the accessor of
// the same name on `params` - confirm against the macro definitions.
CV_WRAP_SAME_PROPERTY(int, MaxCategories, params)
CV_WRAP_SAME_PROPERTY(int, MaxDepth, params)
CV_WRAP_SAME_PROPERTY(int, MinSampleCount, params)
CV_WRAP_SAME_PROPERTY(int, CVFolds, params)
CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, params)
CV_WRAP_SAME_PROPERTY(bool, Use1SERule, params)
CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, params)
CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, params)
// The _S variant is used for the cv::Mat property (macro definition not visible here).
CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, params)
// Construction, destruction and reset. Only the declarations are visible in
// this header section; the bodies are defined elsewhere.
DTreesImpl();
virtual ~DTreesImpl();
virtual void clear();
@@ -202,8 +297,7 @@ namespace ml
// Returns the difference between the second and first entries of catOfs[vi].
// NOTE(review): presumably catOfs[vi] is a [begin, end) offset pair so the
// result is the number of categories of variable vi - confirm where catOfs is filled.
int getCatCount(int vi) const
{
    const int rangeBegin = catOfs[vi][0];
    const int rangeEnd = catOfs[vi][1];
    return rangeEnd - rangeBegin;
}
// Returns getCatCount(vi) divided by 32, rounded up for non-negative counts.
// NOTE(review): presumably the number of 32-bit words needed for a
// one-bit-per-category mask - confirm against the subset storage code.
int getSubsetSize(int vi) const
{
    const int categoryCount = getCatCount(vi);
    return (categoryCount + 31) / 32;
}
// NOTE(review): this text comes from a diff view with its +/- markers
// stripped. The two Params-based declarations below look like the removed
// lines and the TreeParams overload like their replacement - confirm against
// the actual header before relying on either signature.
virtual void setDParams(const Params& _params);
virtual Params getDParams() const;
virtual void setDParams(const TreeParams& _params);
// Training hooks; only the declarations are visible here, bodies are defined elsewhere.
virtual void startTraining( const Ptr<TrainData>& trainData, int flags );
virtual void endTraining();
virtual void initCompVarIdx();
@@ -250,7 +344,7 @@ namespace ml
// Read-only access to the `splits` member; returns a reference, no copy is made.
virtual const std::vector<Split>& getSplits() const { return splits; }
// Read-only access to the `subsets` member; returns a reference, no copy is made.
virtual const std::vector<int>& getSubsets() const { return subsets; }
// NOTE(review): this text comes from a diff view with its +/- markers
// stripped. `Params params0, params;` looks like the removed line and
// `TreeParams params;` like its replacement - both cannot coexist, since
// `params` would be declared twice. Confirm against the actual header.
Params params0, params;
TreeParams params;
// Index vectors; how they are populated is not visible in this section.
vector<int> varIdx;
vector<int> compVarIdx;