Merge pull request #5486 from amroamroamro:fix_ml_randMVNormal
This commit is contained in:
commit
0d791189ee
@ -151,21 +151,28 @@ static void Cholesky( const Mat& A, Mat& S )
|
|||||||
average row vector, <cov> - symmetric covariation matrix */
|
average row vector, <cov> - symmetric covariation matrix */
|
||||||
void randMVNormal( InputArray _mean, InputArray _cov, int nsamples, OutputArray _samples )
|
void randMVNormal( InputArray _mean, InputArray _cov, int nsamples, OutputArray _samples )
|
||||||
{
|
{
|
||||||
|
// check mean vector and covariance matrix
|
||||||
Mat mean = _mean.getMat(), cov = _cov.getMat();
|
Mat mean = _mean.getMat(), cov = _cov.getMat();
|
||||||
int dim = (int)mean.total();
|
int dim = (int)mean.total(); // dimensionality
|
||||||
|
CV_Assert(mean.rows == 1 || mean.cols == 1);
|
||||||
|
CV_Assert(cov.rows == dim && cov.cols == dim);
|
||||||
|
mean = mean.reshape(1,1); // ensure a row vector
|
||||||
|
|
||||||
|
// generate n-samples of the same dimension, from ~N(0,1)
|
||||||
_samples.create(nsamples, dim, CV_32F);
|
_samples.create(nsamples, dim, CV_32F);
|
||||||
Mat samples = _samples.getMat();
|
Mat samples = _samples.getMat();
|
||||||
randu(samples, 0., 1.);
|
randn(samples, Scalar::all(0), Scalar::all(1));
|
||||||
|
|
||||||
|
// decompose covariance using Cholesky: cov = U'*U
|
||||||
|
// (cov must be square, symmetric, and positive semi-definite matrix)
|
||||||
Mat utmat;
|
Mat utmat;
|
||||||
Cholesky(cov, utmat);
|
Cholesky(cov, utmat);
|
||||||
int flags = mean.cols == 1 ? 0 : GEMM_3_T;
|
|
||||||
|
|
||||||
|
// transform random numbers using specified mean and covariance
|
||||||
for( int i = 0; i < nsamples; i++ )
|
for( int i = 0; i < nsamples; i++ )
|
||||||
{
|
{
|
||||||
Mat sample = samples.row(i);
|
Mat sample = samples.row(i);
|
||||||
gemm(sample, utmat, 1, mean, 1, sample, flags);
|
sample = sample * utmat + mean;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user