exposed parallelized SVM prediction to python (predict_all)
This commit is contained in:
parent
e4d9d5294e
commit
a98d6b6217
@ -488,7 +488,7 @@ public:
|
|||||||
bool balanced=false );
|
bool balanced=false );
|
||||||
|
|
||||||
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
|
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
|
||||||
virtual float predict( const CvMat* samples, CvMat* results ) const;
|
virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
|
||||||
|
|
||||||
#ifndef SWIG
|
#ifndef SWIG
|
||||||
CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
|
CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
|
||||||
@ -510,6 +510,7 @@ public:
|
|||||||
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
|
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
|
||||||
bool balanced=false);
|
bool balanced=false);
|
||||||
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
|
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
|
||||||
|
CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
CV_WRAP virtual int get_support_vector_count() const;
|
CV_WRAP virtual int get_support_vector_count() const;
|
||||||
|
@ -2124,6 +2124,12 @@ float CvSVM::predict(const CvMat* samples, CV_OUT CvMat* results) const
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CvSVM::predict( cv::InputArray _samples, cv::OutputArray _results ) const
|
||||||
|
{
|
||||||
|
_results.create(_samples.size().height, 1, CV_32F);
|
||||||
|
CvMat samples = _samples.getMat(), results = _results.getMat();
|
||||||
|
predict(&samples, &results);
|
||||||
|
}
|
||||||
|
|
||||||
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
|
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
|
||||||
const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
|
const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
|
||||||
|
@ -88,7 +88,7 @@ class SVM(LetterStatModel):
|
|||||||
self.model.train(samples, responses, params = params)
|
self.model.train(samples, responses, params = params)
|
||||||
|
|
||||||
def predict(self, samples):
|
def predict(self, samples):
|
||||||
return np.float32( [self.model.predict(s) for s in samples] )
|
return self.model.predict_all(samples).ravel()
|
||||||
|
|
||||||
|
|
||||||
class MLP(LetterStatModel):
|
class MLP(LetterStatModel):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user