Merge pull request #5486 from amroamroamro:fix_ml_randMVNormal
This commit is contained in:
commit
0d791189ee
@ -151,21 +151,28 @@ static void Cholesky( const Mat& A, Mat& S )
|
||||
average row vector, <cov> - symmetric covariation matrix */
|
||||
void randMVNormal( InputArray _mean, InputArray _cov, int nsamples, OutputArray _samples )
|
||||
{
|
||||
// check mean vector and covariance matrix
|
||||
Mat mean = _mean.getMat(), cov = _cov.getMat();
|
||||
int dim = (int)mean.total();
|
||||
int dim = (int)mean.total(); // dimensionality
|
||||
CV_Assert(mean.rows == 1 || mean.cols == 1);
|
||||
CV_Assert(cov.rows == dim && cov.cols == dim);
|
||||
mean = mean.reshape(1,1); // ensure a row vector
|
||||
|
||||
// generate n-samples of the same dimension, from ~N(0,1)
|
||||
_samples.create(nsamples, dim, CV_32F);
|
||||
Mat samples = _samples.getMat();
|
||||
randu(samples, 0., 1.);
|
||||
randn(samples, Scalar::all(0), Scalar::all(1));
|
||||
|
||||
// decompose covariance using Cholesky: cov = U'*U
|
||||
// (cov must be square, symmetric, and positive semi-definite matrix)
|
||||
Mat utmat;
|
||||
Cholesky(cov, utmat);
|
||||
int flags = mean.cols == 1 ? 0 : GEMM_3_T;
|
||||
|
||||
// transform random numbers using specified mean and covariance
|
||||
for( int i = 0; i < nsamples; i++ )
|
||||
{
|
||||
Mat sample = samples.row(i);
|
||||
gemm(sample, utmat, 1, mean, 1, sample, flags);
|
||||
sample = sample * utmat + mean;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user