Updated ml module interfaces and documentation
This commit is contained in:
@@ -54,48 +54,33 @@ log_ratio( double val )
|
||||
}
|
||||
|
||||
|
||||
Boost::Params::Params()
|
||||
BoostTreeParams::BoostTreeParams()
|
||||
{
|
||||
boostType = Boost::REAL;
|
||||
weakCount = 100;
|
||||
weightTrimRate = 0.95;
|
||||
CVFolds = 0;
|
||||
maxDepth = 1;
|
||||
}
|
||||
|
||||
|
||||
Boost::Params::Params( int _boostType, int _weak_count,
|
||||
double _weightTrimRate, int _maxDepth,
|
||||
bool _use_surrogates, const Mat& _priors )
|
||||
BoostTreeParams::BoostTreeParams( int _boostType, int _weak_count,
|
||||
double _weightTrimRate)
|
||||
{
|
||||
boostType = _boostType;
|
||||
weakCount = _weak_count;
|
||||
weightTrimRate = _weightTrimRate;
|
||||
CVFolds = 0;
|
||||
maxDepth = _maxDepth;
|
||||
useSurrogates = _use_surrogates;
|
||||
priors = _priors;
|
||||
}
|
||||
|
||||
|
||||
class DTreesImplForBoost : public DTreesImpl
|
||||
{
|
||||
public:
|
||||
DTreesImplForBoost() {}
|
||||
DTreesImplForBoost()
|
||||
{
|
||||
params.setCVFolds(0);
|
||||
params.setMaxDepth(1);
|
||||
}
|
||||
virtual ~DTreesImplForBoost() {}
|
||||
|
||||
bool isClassifier() const { return true; }
|
||||
|
||||
void setBParams(const Boost::Params& p)
|
||||
{
|
||||
bparams = p;
|
||||
}
|
||||
|
||||
Boost::Params getBParams() const
|
||||
{
|
||||
return bparams;
|
||||
}
|
||||
|
||||
void clear()
|
||||
{
|
||||
DTreesImpl::clear();
|
||||
@@ -199,10 +184,6 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags )
|
||||
{
|
||||
Params dp(bparams.maxDepth, bparams.minSampleCount, bparams.regressionAccuracy,
|
||||
bparams.useSurrogates, bparams.maxCategories, 0,
|
||||
false, false, bparams.priors);
|
||||
setDParams(dp);
|
||||
startTraining(trainData, flags);
|
||||
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
|
||||
vector<int> sidx = w->sidx;
|
||||
@@ -426,12 +407,6 @@ public:
|
||||
void readParams( const FileNode& fn )
|
||||
{
|
||||
DTreesImpl::readParams(fn);
|
||||
bparams.maxDepth = params0.maxDepth;
|
||||
bparams.minSampleCount = params0.minSampleCount;
|
||||
bparams.regressionAccuracy = params0.regressionAccuracy;
|
||||
bparams.useSurrogates = params0.useSurrogates;
|
||||
bparams.maxCategories = params0.maxCategories;
|
||||
bparams.priors = params0.priors;
|
||||
|
||||
FileNode tparams_node = fn["training_params"];
|
||||
// check for old layout
|
||||
@@ -465,7 +440,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
Boost::Params bparams;
|
||||
BoostTreeParams bparams;
|
||||
vector<double> sumResult;
|
||||
};
|
||||
|
||||
@@ -476,6 +451,20 @@ public:
|
||||
BoostImpl() {}
|
||||
virtual ~BoostImpl() {}
|
||||
|
||||
CV_IMPL_PROPERTY(int, BoostType, impl.bparams.boostType)
|
||||
CV_IMPL_PROPERTY(int, WeakCount, impl.bparams.weakCount)
|
||||
CV_IMPL_PROPERTY(double, WeightTrimRate, impl.bparams.weightTrimRate)
|
||||
|
||||
CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
|
||||
CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)
|
||||
|
||||
String getDefaultModelName() const { return "opencv_ml_boost"; }
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags )
|
||||
@@ -498,9 +487,6 @@ public:
|
||||
impl.read(fn);
|
||||
}
|
||||
|
||||
void setBParams(const Params& p) { impl.setBParams(p); }
|
||||
Params getBParams() const { return impl.getBParams(); }
|
||||
|
||||
int getVarCount() const { return impl.getVarCount(); }
|
||||
|
||||
bool isTrained() const { return impl.isTrained(); }
|
||||
@@ -515,11 +501,9 @@ public:
|
||||
};
|
||||
|
||||
|
||||
Ptr<Boost> Boost::create(const Params& params)
|
||||
Ptr<Boost> Boost::create()
|
||||
{
|
||||
Ptr<BoostImpl> p = makePtr<BoostImpl>();
|
||||
p->setBParams(params);
|
||||
return p;
|
||||
return makePtr<BoostImpl>();
|
||||
}
|
||||
|
||||
}}
|
||||
|
Reference in New Issue
Block a user