Updated ml module interfaces and documentation

This commit is contained in:
Maksim Shabunin
2015-02-11 13:24:14 +03:00
parent da383e65e2
commit 79e8f0680c
32 changed files with 1403 additions and 1528 deletions

View File

@@ -54,48 +54,33 @@ log_ratio( double val )
}
Boost::Params::Params()
BoostTreeParams::BoostTreeParams()
{
boostType = Boost::REAL;
weakCount = 100;
weightTrimRate = 0.95;
CVFolds = 0;
maxDepth = 1;
}
Boost::Params::Params( int _boostType, int _weak_count,
double _weightTrimRate, int _maxDepth,
bool _use_surrogates, const Mat& _priors )
BoostTreeParams::BoostTreeParams( int _boostType, int _weak_count,
double _weightTrimRate)
{
boostType = _boostType;
weakCount = _weak_count;
weightTrimRate = _weightTrimRate;
CVFolds = 0;
maxDepth = _maxDepth;
useSurrogates = _use_surrogates;
priors = _priors;
}
class DTreesImplForBoost : public DTreesImpl
{
public:
DTreesImplForBoost() {}
DTreesImplForBoost()
{
params.setCVFolds(0);
params.setMaxDepth(1);
}
virtual ~DTreesImplForBoost() {}
bool isClassifier() const { return true; }
void setBParams(const Boost::Params& p)
{
bparams = p;
}
Boost::Params getBParams() const
{
return bparams;
}
void clear()
{
DTreesImpl::clear();
@@ -199,10 +184,6 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags )
{
Params dp(bparams.maxDepth, bparams.minSampleCount, bparams.regressionAccuracy,
bparams.useSurrogates, bparams.maxCategories, 0,
false, false, bparams.priors);
setDParams(dp);
startTraining(trainData, flags);
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
vector<int> sidx = w->sidx;
@@ -426,12 +407,6 @@ public:
void readParams( const FileNode& fn )
{
DTreesImpl::readParams(fn);
bparams.maxDepth = params0.maxDepth;
bparams.minSampleCount = params0.minSampleCount;
bparams.regressionAccuracy = params0.regressionAccuracy;
bparams.useSurrogates = params0.useSurrogates;
bparams.maxCategories = params0.maxCategories;
bparams.priors = params0.priors;
FileNode tparams_node = fn["training_params"];
// check for old layout
@@ -465,7 +440,7 @@ public:
}
}
Boost::Params bparams;
BoostTreeParams bparams;
vector<double> sumResult;
};
@@ -476,6 +451,20 @@ public:
BoostImpl() {}
virtual ~BoostImpl() {}
CV_IMPL_PROPERTY(int, BoostType, impl.bparams.boostType)
CV_IMPL_PROPERTY(int, WeakCount, impl.bparams.weakCount)
CV_IMPL_PROPERTY(double, WeightTrimRate, impl.bparams.weightTrimRate)
CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)
String getDefaultModelName() const { return "opencv_ml_boost"; }
bool train( const Ptr<TrainData>& trainData, int flags )
@@ -498,9 +487,6 @@ public:
impl.read(fn);
}
void setBParams(const Params& p) { impl.setBParams(p); }
Params getBParams() const { return impl.getBParams(); }
int getVarCount() const { return impl.getVarCount(); }
bool isTrained() const { return impl.isTrained(); }
@@ -515,11 +501,9 @@ public:
};
Ptr<Boost> Boost::create(const Params& params)
Ptr<Boost> Boost::create()
{
Ptr<BoostImpl> p = makePtr<BoostImpl>();
p->setBParams(params);
return p;
return makePtr<BoostImpl>();
}
}}