diff --git a/src/common_audio/vad/vad_gmm.c b/src/common_audio/vad/vad_gmm.c index 086526177..20a703af0 100644 --- a/src/common_audio/vad/vad_gmm.c +++ b/src/common_audio/vad/vad_gmm.c @@ -8,68 +8,76 @@ * be found in the AUTHORS file in the root of the source tree. */ - -/* - * This file includes the implementation of the internal VAD call - * WebRtcVad_GaussianProbability. For function description, see vad_gmm.h. - */ - #include "vad_gmm.h" #include "signal_processing_library.h" #include "typedefs.h" -static const WebRtc_Word32 kCompVar = 22005; -// Constant log2(exp(1)) in Q12 -static const WebRtc_Word16 kLog10Const = 5909; +static const int32_t kCompVar = 22005; +static const int16_t kLog2Exp = 5909; // log2(exp(1)) in Q12. -WebRtc_Word32 WebRtcVad_GaussianProbability(WebRtc_Word16 in_sample, - WebRtc_Word16 mean, - WebRtc_Word16 std, - WebRtc_Word16 *delta) -{ - WebRtc_Word16 tmp16, tmpDiv, tmpDiv2, expVal, tmp16_1, tmp16_2; - WebRtc_Word32 tmp32, y32; +// For a normal distribution, the probability of |input| is calculated and +// returned (in Q20). The formula for normal distributed probability is +// +// 1 / s * exp(-(x - m)^2 / (2 * s^2)) +// +// where the parameters are given in the following Q domains: +// m = |mean| (Q7) +// s = |std| (Q7) +// x = |input| (Q4) +// in addition to the probability we output |delta| (in Q11) used when updating +// the noise/speech model. +int32_t WebRtcVad_GaussianProbability(int16_t input, + int16_t mean, + int16_t std, + int16_t* delta) { + int16_t tmp16, inv_std, inv_std2, exp_value = 0; + int32_t tmp32; - // Calculate tmpDiv=1/std, in Q10 - tmp32 = (WebRtc_Word32)WEBRTC_SPL_RSHIFT_W16(std,1) + (WebRtc_Word32)131072; // 1 in Q17 - tmpDiv = (WebRtc_Word16)WebRtcSpl_DivW32W16(tmp32, std); // Q17/Q7 = Q10 + // Calculate |inv_std| = 1 / s, in Q10. + // 131072 = 1 in Q17, and (|std| >> 1) is for rounding instead of truncation. + // Q-domain: Q17 / Q7 = Q10. + tmp32 = (int32_t) 131072 + (int32_t) (std >> 1); + inv_std = (int16_t) WebRtcSpl_DivW32W16(tmp32, std); - // Calculate tmpDiv2=1/std^2, in Q14 - tmp16 = WEBRTC_SPL_RSHIFT_W16(tmpDiv, 2); // From Q10 to Q8 - tmpDiv2 = (WebRtc_Word16)WEBRTC_SPL_MUL_16_16_RSFT(tmp16, tmp16, 2); // (Q8 * Q8)>>2 = Q14 + // Calculate |inv_std2| = 1 / s^2, in Q14. + tmp16 = (inv_std >> 2); // Q10 -> Q8. + // Q-domain: (Q8 * Q8) >> 2 = Q14. + inv_std2 = (int16_t) WEBRTC_SPL_MUL_16_16_RSFT(tmp16, tmp16, 2); + // TODO(bjornv): Investigate if changing to + // |inv_std2| = (int16_t) WEBRTC_SPL_MUL_16_16_RSFT(|inv_std|, |inv_std|, 6); + // gives better accuracy. - tmp16 = WEBRTC_SPL_LSHIFT_W16(in_sample, 3); // Q7 - tmp16 = tmp16 - mean; // Q7 - Q7 = Q7 + tmp16 = (input << 3); // Q4 -> Q7 + tmp16 = tmp16 - mean; // Q7 - Q7 = Q7 - // To be used later, when updating noise/speech model - // delta = (x-m)/std^2, in Q11 - *delta = (WebRtc_Word16)WEBRTC_SPL_MUL_16_16_RSFT(tmpDiv2, tmp16, 10); //(Q14*Q7)>>10 = Q11 + // To be used later, when updating noise/speech model. + // |delta| = (x - m) / s^2, in Q11. + // Q-domain: (Q14 * Q7) >> 10 = Q11. + *delta = (int16_t) WEBRTC_SPL_MUL_16_16_RSFT(inv_std2, tmp16, 10); - // Calculate tmp32=(x-m)^2/(2*std^2), in Q10 - tmp32 = (WebRtc_Word32)WEBRTC_SPL_MUL_16_16_RSFT(*delta, tmp16, 9); // One shift for /2 + // Calculate the exponent |tmp32| = (x - m)^2 / (2 * s^2), in Q10. Replacing + // division by two with one shift. + // Q-domain: (Q11 * Q7) >> 8 = Q10. + tmp32 = WEBRTC_SPL_MUL_16_16_RSFT(*delta, tmp16, 9); - // Calculate expVal ~= exp(-(x-m)^2/(2*std^2)) ~= exp2(-log2(exp(1))*tmp32) - if (tmp32 < kCompVar) - { - // Calculate tmp16 = log2(exp(1))*tmp32 , in Q10 - tmp16 = (WebRtc_Word16)WEBRTC_SPL_MUL_16_16_RSFT((WebRtc_Word16)tmp32, - kLog10Const, 12); - tmp16 = -tmp16; - tmp16_2 = (WebRtc_Word16)(0x0400 | (tmp16 & 0x03FF)); - tmp16_1 = (WebRtc_Word16)(tmp16 ^ 0xFFFF); - tmp16 = (WebRtc_Word16)WEBRTC_SPL_RSHIFT_W16(tmp16_1, 10); - tmp16 += 1; - // Calculate expVal=log2(-tmp32), in Q10 - expVal = (WebRtc_Word16)WEBRTC_SPL_RSHIFT_W32((WebRtc_Word32)tmp16_2, tmp16); + // If the exponent is small enough to give a non-zero probability we calculate + // |exp_value| ~= exp(-(x - m)^2 / (2 * s^2)) + // ~= exp2(-log2(exp(1)) * |tmp32|). + if (tmp32 < kCompVar) { + // Calculate |tmp16| = log2(exp(1)) * |tmp32|, in Q10. + // Q-domain: (Q12 * Q10) >> 12 = Q10. + tmp16 = (int16_t) WEBRTC_SPL_MUL_16_16_RSFT(kLog2Exp, (int16_t) tmp32, 12); + tmp16 = -tmp16; + exp_value = (0x0400 | (tmp16 & 0x03FF)); + tmp16 ^= 0xFFFF; + tmp16 >>= 10; + tmp16 += 1; + // Get |exp_value| = exp(-|tmp32|) in Q10. + exp_value >>= tmp16; + } - } else - { - expVal = 0; - } - - // Calculate y32=(1/std)*exp(-(x-m)^2/(2*std^2)), in Q20 - y32 = WEBRTC_SPL_MUL_16_16(tmpDiv, expVal); // Q10 * Q10 = Q20 - - return y32; // Q20 + // Calculate and return (1 / s) * exp(-(x - m)^2 / (2 * s^2)), in Q20. + // Q-domain: Q10 * Q10 = Q20. + return WEBRTC_SPL_MUL_16_16(inv_std, exp_value); } diff --git a/src/common_audio/vad/vad_gmm.h b/src/common_audio/vad/vad_gmm.h index e0747fb7e..2333af7e0 100644 --- a/src/common_audio/vad/vad_gmm.h +++ b/src/common_audio/vad/vad_gmm.h @@ -8,40 +8,32 @@ * be found in the AUTHORS file in the root of the source tree. */ +// Gaussian probability calculations internally used in vad_core.c. -/* - * This header file includes the description of the internal VAD call - * WebRtcVad_GaussianProbability. - */ - -#ifndef WEBRTC_VAD_GMM_H_ -#define WEBRTC_VAD_GMM_H_ +#ifndef WEBRTC_COMMON_AUDIO_VAD_VAD_GMM_H_ +#define WEBRTC_COMMON_AUDIO_VAD_VAD_GMM_H_ #include "typedefs.h" -/**************************************************************************** - * WebRtcVad_GaussianProbability(...) - * - * This function calculates the probability for the value 'in_sample', given that in_sample - * comes from a normal distribution with mean 'mean' and standard deviation 'std'. - * - * Input: - * - in_sample : Input sample in Q4 - * - mean : mean value in the statistical model, Q7 - * - std : standard deviation, Q7 - * - * Output: - * - * - delta : Value used when updating the model, Q11 - * - * Return: - * - out : out = 1/std * exp(-(x-m)^2/(2*std^2)); - * Probability for x. - * - */ -WebRtc_Word32 WebRtcVad_GaussianProbability(WebRtc_Word16 in_sample, - WebRtc_Word16 mean, - WebRtc_Word16 std, - WebRtc_Word16 *delta); +// Calculates the probability for |input|, given that |input| comes from a +// normal distribution with mean and standard deviation (|mean|, |std|). +// +// Inputs: +// - input : input sample in Q4. +// - mean : mean input in the statistical model, Q7. +// - std : standard deviation, Q7. +// +// Output: +// +// - delta : input used when updating the model, Q11. +// |delta| = (|input| - |mean|) / |std|^2. +// +// Return: +// (probability for |input|) = +// 1 / |std| * exp(-(|input| - |mean|)^2 / (2 * |std|^2)); +int32_t WebRtcVad_GaussianProbability(int16_t input, + int16_t mean, + int16_t std, + int16_t* delta); -#endif // WEBRTC_VAD_GMM_H_ +#endif // WEBRTC_COMMON_AUDIO_VAD_VAD_GMM_H_ diff --git a/src/common_audio/vad/vad_unittest.cc b/src/common_audio/vad/vad_unittest.cc index 1fcb6111d..66e5ffad5 100644 --- a/src/common_audio/vad/vad_unittest.cc +++ b/src/common_audio/vad/vad_unittest.cc @@ -8,13 +8,20 @@ * be found in the AUTHORS file in the root of the source tree. */ -#include // size_t +#include // size_t #include #include "gtest/gtest.h" #include "typedefs.h" #include "webrtc_vad.h" +#ifdef __cplusplus +extern "C" +{ +#include "vad_gmm.h" +} +#endif + namespace webrtc { namespace { const int16_t kModes[] = { 0, 1, 2, 3 }; @@ -152,6 +159,29 @@ TEST_F(VadTest, ApiTest) { EXPECT_EQ(0, WebRtcVad_Free(handle)); } +TEST_F(VadTest, GMMTests) { + int16_t delta = 0; + // Input value at mean. + EXPECT_EQ(1048576, WebRtcVad_GaussianProbability(0, 0, 128, &delta)); + EXPECT_EQ(0, delta); + EXPECT_EQ(1048576, WebRtcVad_GaussianProbability(16, 128, 128, &delta)); + EXPECT_EQ(0, delta); + EXPECT_EQ(1048576, WebRtcVad_GaussianProbability(-16, -128, 128, &delta)); + EXPECT_EQ(0, delta); + + // Largest possible input to give non-zero probability. + EXPECT_EQ(1024, WebRtcVad_GaussianProbability(59, 0, 128, &delta)); + EXPECT_EQ(7552, delta); + EXPECT_EQ(1024, WebRtcVad_GaussianProbability(75, 128, 128, &delta)); + EXPECT_EQ(7552, delta); + EXPECT_EQ(1024, WebRtcVad_GaussianProbability(-75, -128, 128, &delta)); + EXPECT_EQ(-7552, delta); + + // Too large input, should give zero probability. + EXPECT_EQ(0, WebRtcVad_GaussianProbability(105, 0, 128, &delta)); + EXPECT_EQ(13440, delta); +} + // TODO(bjornv): Add a process test, run on file. } // namespace