Refactor vad_filterbank: Some restructuring.

- Removed unnecessary type casting.
- Added comments.
- Removed shift macros.
- Renamed WebRtcVad_get_features() to WebRtcVad_CalculateFeatures(); affects vad_core.c and vad_filterbank_unittest.cc.
Review URL: http://webrtc-codereview.appspot.com/343002

git-svn-id: http://webrtc.googlecode.com/svn/trunk@1371 4adac7df-926f-26a2-2b94-8c16560cd09d
Author: bjornv@webrtc.org
Date:   2012-01-10 13:48:09 +00:00
Commit: d1f148da77 (parent d4e8c0b3ff)
4 changed files with 116 additions and 124 deletions

File: vad_core.c

@@ -316,7 +316,8 @@ WebRtc_Word16 WebRtcVad_CalcVad8khz(VadInstT *inst, WebRtc_Word16 *speech_frame,
     WebRtc_Word16 feature_vector[NUM_CHANNELS], total_power;

     // Get power in the bands
-    total_power = WebRtcVad_get_features(inst, speech_frame, frame_length, feature_vector);
+    total_power = WebRtcVad_CalculateFeatures(inst, speech_frame, frame_length,
+                                              feature_vector);

     // Make a VAD
     inst->vad = WebRtcVad_GmmProbability(inst, feature_vector, total_power, frame_length);

File: vad_filterbank.c

@@ -8,13 +8,10 @@
  * be found in the AUTHORS file in the root of the source tree.
  */

-/*
- * This file includes the implementation of the internal filterbank associated functions.
- * For function description, see vad_filterbank.h.
- */
-
 #include "vad_filterbank.h"

+#include <assert.h>
+
 #include "signal_processing_library.h"
 #include "typedefs.h"
 #include "vad_defines.h"
@@ -58,20 +55,18 @@ static void HighPassFilter(const int16_t* in_vector, int in_vector_length,
   // Impulse response: 1.0000 0.4734 -0.1189 -0.2187 -0.0627 0.04532
   for (i = 0; i < in_vector_length; i++) {
-    // all-zero section (filter coefficients in Q14)
-    tmp32 = (int32_t) WEBRTC_SPL_MUL_16_16(kHpZeroCoefs[0], (*in_ptr));
-    tmp32 += (int32_t) WEBRTC_SPL_MUL_16_16(kHpZeroCoefs[1], filter_state[0]);
-    tmp32 += (int32_t) WEBRTC_SPL_MUL_16_16(kHpZeroCoefs[2],
-                                            filter_state[1]);  // Q14
+    // All-zero section (filter coefficients in Q14).
+    tmp32 = WEBRTC_SPL_MUL_16_16(kHpZeroCoefs[0], *in_ptr);
+    tmp32 += WEBRTC_SPL_MUL_16_16(kHpZeroCoefs[1], filter_state[0]);
+    tmp32 += WEBRTC_SPL_MUL_16_16(kHpZeroCoefs[2], filter_state[1]);
     filter_state[1] = filter_state[0];
     filter_state[0] = *in_ptr++;

-    // all-pole section
-    tmp32 -= (int32_t) WEBRTC_SPL_MUL_16_16(kHpPoleCoefs[1],
-                                            filter_state[2]);  // Q14
-    tmp32 -= (int32_t) WEBRTC_SPL_MUL_16_16(kHpPoleCoefs[2], filter_state[3]);
+    // All-pole section (filter coefficients in Q14).
+    tmp32 -= WEBRTC_SPL_MUL_16_16(kHpPoleCoefs[1], filter_state[2]);
+    tmp32 -= WEBRTC_SPL_MUL_16_16(kHpPoleCoefs[2], filter_state[3]);
     filter_state[3] = filter_state[2];
-    filter_state[2] = (int16_t) WEBRTC_SPL_RSHIFT_W32 (tmp32, 14);
+    filter_state[2] = (int16_t) (tmp32 >> 14);
     *out_ptr++ = filter_state[2];
   }
 }
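For readers who do not work with the fixed-point library every day, here is a minimal, self-contained sketch of the Q14 biquad pattern used in the hunk above: an all-zero accumulate, an all-pole subtract, and a final shift by 14 back to the sample domain. The state layout mirrors the diff, but the function name, coefficients and values are placeholders, not the real kHpZeroCoefs/kHpPoleCoefs.

#include <stdint.h>

// One step of a direct-form Q14 biquad, structured like HighPassFilter()
// above. |zero| and |pole| hold b0..b2 and a0..a2 in Q14 (value = coef / 16384);
// state[0..1] are the two previous inputs, state[2..3] the two previous
// outputs. Placeholder coefficients only; not the production tables.
static int16_t HighPassStepQ14(int16_t in, const int16_t zero[3],
                               const int16_t pole[3], int16_t state[4]) {
  // All-zero section: b0*x[n] + b1*x[n-1] + b2*x[n-2], accumulated in Q14.
  int32_t acc = (int32_t) zero[0] * in;
  acc += (int32_t) zero[1] * state[0];
  acc += (int32_t) zero[2] * state[1];
  state[1] = state[0];
  state[0] = in;

  // All-pole section: subtract a1*y[n-1] + a2*y[n-2], also in Q14.
  acc -= (int32_t) pole[1] * state[2];
  acc -= (int32_t) pole[2] * state[3];
  state[3] = state[2];
  state[2] = (int16_t) (acc >> 14);  // Q14 -> Q0, same as the diff's >> 14.
  return state[2];
}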
@@ -92,25 +87,25 @@ static void AllPassFilter(const int16_t* in_vector, int vector_length,
   // The filter can only cause overflow (in the w16 output variable)
   // if more than 4 consecutive input numbers are of maximum value and
   // has the the same sign as the impulse responses first taps.
-  // First 6 taps of the impulse response: 0.6399 0.5905 -0.3779
-  // 0.2418 -0.1547 0.0990
+  // First 6 taps of the impulse response:
+  // 0.6399 0.5905 -0.3779 0.2418 -0.1547 0.0990

   int i;
   int16_t tmp16 = 0;
-  int32_t tmp32 = 0, in32 = 0;
-  int32_t state32 = WEBRTC_SPL_LSHIFT_W32((int32_t) (*filter_state), 16);  // Q31
+  int32_t tmp32 = 0;
+  int32_t state32 = ((int32_t) (*filter_state) << 16);  // Q15

   for (i = 0; i < vector_length; i++) {
-    tmp32 = state32 + WEBRTC_SPL_MUL_16_16(filter_coefficient, (*in_vector));
-    tmp16 = (int16_t) WEBRTC_SPL_RSHIFT_W32(tmp32, 16);
+    tmp32 = state32 + WEBRTC_SPL_MUL_16_16(filter_coefficient, *in_vector);
+    tmp16 = (int16_t) (tmp32 >> 16);  // Q(-1)
     *out_vector++ = tmp16;

-    in32 = WEBRTC_SPL_LSHIFT_W32(((int32_t) (*in_vector)), 14);
-    state32 = in32 - WEBRTC_SPL_MUL_16_16(filter_coefficient, tmp16);
-    state32 = WEBRTC_SPL_LSHIFT_W32(state32, 1);
+    state32 = (((int32_t) (*in_vector)) << 14);  // Q14
+    state32 -= WEBRTC_SPL_MUL_16_16(filter_coefficient, tmp16);  // Q14
+    state32 <<= 1;  // Q15.
     in_vector += 2;
   }

-  *filter_state = (int16_t) WEBRTC_SPL_RSHIFT_W32(state32, 16);
+  *filter_state = (int16_t) (state32 >> 16);  // Q(-1)
 }

 // Splits |in_vector| into |out_vector_hp| and |out_vector_lp| corresponding to
@@ -128,19 +123,19 @@ static void AllPassFilter(const int16_t* in_vector, int vector_length,
 static void SplitFilter(const int16_t* in_vector, int in_vector_length,
                         int16_t* upper_state, int16_t* lower_state,
                         int16_t* out_vector_hp, int16_t* out_vector_lp) {
-  int16_t tmp_out;
   int i;
-  int half_length = WEBRTC_SPL_RSHIFT_W16(in_vector_length, 1);
+  int half_length = in_vector_length >> 1;  // Downsampling by 2.
+  int16_t tmp_out;

-  // All-pass filtering upper branch
+  // All-pass filtering upper branch.
   AllPassFilter(&in_vector[0], half_length, kAllPassCoefsQ15[0], upper_state,
                 out_vector_hp);

-  // All-pass filtering lower branch
+  // All-pass filtering lower branch.
   AllPassFilter(&in_vector[1], half_length, kAllPassCoefsQ15[1], lower_state,
                 out_vector_lp);

-  // Make LP and HP signals
+  // Make LP and HP signals.
   for (i = 0; i < half_length; i++) {
     tmp_out = *out_vector_hp;
     *out_vector_hp++ -= *out_vector_lp;
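The structure in this hunk is a polyphase half-band split: even-indexed samples feed one first-order all-pass section, odd-indexed samples feed another, and the sum and difference of the two branch outputs give the low and high band at half the input rate. Below is a rough floating-point equivalent for orientation only; the function name, coefficients and scaling are illustrative, and the production code keeps the branch outputs in Q(-1), which is where the implicit factor of one half comes from.

// Floating-point sketch of the half-band split done by AllPassFilter() +
// SplitFilter() above. Not production code: the real coefficients live in
// kAllPassCoefsQ15 and the arithmetic is Q15/Q14 fixed point.
static void SplitHalfBandFloat(const float* in, int in_length,
                               float coef_upper, float coef_lower,
                               float* state_upper, float* state_lower,
                               float* out_hp, float* out_lp) {
  int i;
  int half_length = in_length / 2;  // Both outputs run at half the input rate.
  for (i = 0; i < half_length; i++) {
    // First-order all-pass on the even samples (upper branch):
    // y[n] = state + c * x[n]; state = x[n] - c * y[n].
    float x0 = in[2 * i];
    float y0 = *state_upper + coef_upper * x0;
    *state_upper = x0 - coef_upper * y0;
    // Same all-pass form on the odd samples (lower branch).
    float x1 = in[2 * i + 1];
    float y1 = *state_lower + coef_lower * x1;
    *state_lower = x1 - coef_lower * y1;
    // Sum -> low band, difference -> high band (cf. the loop in SplitFilter).
    out_lp[i] = 0.5f * (y0 + y1);
    out_hp[i] = 0.5f * (y0 - y1);
  }
}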
@@ -211,87 +206,91 @@ static void LogOfEnergy(const int16_t* in_vector, int vector_length,
   }
 }

-int16_t WebRtcVad_get_features(VadInstT* inst, const int16_t* in_vector,
-                               int frame_size, int16_t* out_vector) {
+int16_t WebRtcVad_CalculateFeatures(VadInstT* self, const int16_t* data_in,
+                                    int data_length, int16_t* data_out) {
   int16_t power = 0;
-  // We expect |frame_size| to be 80, 160 or 240 samples, which corresponds to
+  // We expect |data_length| to be 80, 160 or 240 samples, which corresponds to
   // 10, 20 or 30 ms in 8 kHz. Therefore, the intermediate downsampled data will
   // have at most 120 samples after the first split and at most 60 samples after
   // the second split.
   int16_t hp_120[120], lp_120[120];
   int16_t hp_60[60], lp_60[60];
+  const int half_data_length = data_length >> 1;
+  int length = half_data_length;  // |data_length| / 2, corresponds to
+                                  // bandwidth = 2000 Hz after downsampling.
   // Initialize variables for the first SplitFilter().
-  int length = frame_size;
   int frequency_band = 0;
-  const int16_t* in_ptr = in_vector;
-  int16_t* hp_out_ptr = hp_120;
-  int16_t* lp_out_ptr = lp_120;
+  const int16_t* in_ptr = data_in;  // [0 - 4000] Hz.
+  int16_t* hp_out_ptr = hp_120;  // [2000 - 4000] Hz.
+  int16_t* lp_out_ptr = lp_120;  // [0 - 2000] Hz.

-  // Split at 2000 Hz and downsample
-  SplitFilter(in_ptr, length, &inst->upper_state[frequency_band],
-              &inst->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
+  assert(data_length >= 0);
+  assert(data_length <= 240);
+  assert(4 < NUM_CHANNELS - 1);  // Checking maximum |frequency_band|.

-  // Split at 3000 Hz and downsample
+  // Split at 2000 Hz and downsample.
+  SplitFilter(in_ptr, data_length, &self->upper_state[frequency_band],
+              &self->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
+
+  // For the upper band (2000 Hz - 4000 Hz) split at 3000 Hz and downsample.
   frequency_band = 1;
-  in_ptr = hp_120;
-  hp_out_ptr = hp_60;
-  lp_out_ptr = lp_60;
-  length = WEBRTC_SPL_RSHIFT_W16(frame_size, 1);
-
-  SplitFilter(in_ptr, length, &inst->upper_state[frequency_band],
-              &inst->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
-
-  // Energy in 3000 Hz - 4000 Hz
-  length = WEBRTC_SPL_RSHIFT_W16(length, 1);
-  LogOfEnergy(hp_60, length, kOffsetVector[5], &power, &out_vector[5]);
-
-  // Energy in 2000 Hz - 3000 Hz
-  LogOfEnergy(lp_60, length, kOffsetVector[4], &power, &out_vector[4]);
-
-  // Split at 1000 Hz and downsample
+  in_ptr = hp_120;  // [2000 - 4000] Hz.
+  hp_out_ptr = hp_60;  // [3000 - 4000] Hz.
+  lp_out_ptr = lp_60;  // [2000 - 3000] Hz.
+  SplitFilter(in_ptr, length, &self->upper_state[frequency_band],
+              &self->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
+
+  // Energy in 3000 Hz - 4000 Hz.
+  length >>= 1;  // |data_length| / 4 <=> bandwidth = 1000 Hz.
+  LogOfEnergy(hp_60, length, kOffsetVector[5], &power, &data_out[5]);
+
+  // Energy in 2000 Hz - 3000 Hz.
+  LogOfEnergy(lp_60, length, kOffsetVector[4], &power, &data_out[4]);
+
+  // For the lower band (0 Hz - 2000 Hz) split at 1000 Hz and downsample.
   frequency_band = 2;
-  in_ptr = lp_120;
-  hp_out_ptr = hp_60;
-  lp_out_ptr = lp_60;
-  length = WEBRTC_SPL_RSHIFT_W16(frame_size, 1);
-
-  SplitFilter(in_ptr, length, &inst->upper_state[frequency_band],
-              &inst->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
-
-  // Energy in 1000 Hz - 2000 Hz
-  length = WEBRTC_SPL_RSHIFT_W16(length, 1);
-  LogOfEnergy(hp_60, length, kOffsetVector[3], &power, &out_vector[3]);
-
-  // Split at 500 Hz
+  in_ptr = lp_120;  // [0 - 2000] Hz.
+  hp_out_ptr = hp_60;  // [1000 - 2000] Hz.
+  lp_out_ptr = lp_60;  // [0 - 1000] Hz.
+  length = half_data_length;  // |data_length| / 2 <=> bandwidth = 2000 Hz.
+  SplitFilter(in_ptr, length, &self->upper_state[frequency_band],
+              &self->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
+
+  // Energy in 1000 Hz - 2000 Hz.
+  length >>= 1;  // |data_length| / 4 <=> bandwidth = 1000 Hz.
+  LogOfEnergy(hp_60, length, kOffsetVector[3], &power, &data_out[3]);
+
+  // For the lower band (0 Hz - 1000 Hz) split at 500 Hz and downsample.
   frequency_band = 3;
-  in_ptr = lp_60;
-  hp_out_ptr = hp_120;
-  lp_out_ptr = lp_120;
-
-  SplitFilter(in_ptr, length, &inst->upper_state[frequency_band],
-              &inst->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
-
-  // Energy in 500 Hz - 1000 Hz
-  length = WEBRTC_SPL_RSHIFT_W16(length, 1);
-  LogOfEnergy(hp_120, length, kOffsetVector[2], &power, &out_vector[2]);
-
-  // Split at 250 Hz
+  in_ptr = lp_60;  // [0 - 1000] Hz.
+  hp_out_ptr = hp_120;  // [500 - 1000] Hz.
+  lp_out_ptr = lp_120;  // [0 - 500] Hz.
+  SplitFilter(in_ptr, length, &self->upper_state[frequency_band],
+              &self->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
+
+  // Energy in 500 Hz - 1000 Hz.
+  length >>= 1;  // |data_length| / 8 <=> bandwidth = 500 Hz.
+  LogOfEnergy(hp_120, length, kOffsetVector[2], &power, &data_out[2]);
+
+  // For the lower band (0 Hz - 500 Hz) split at 250 Hz and downsample.
   frequency_band = 4;
-  in_ptr = lp_120;
-  hp_out_ptr = hp_60;
-  lp_out_ptr = lp_60;
-
-  SplitFilter(in_ptr, length, &inst->upper_state[frequency_band],
-              &inst->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
-
-  // Energy in 250 Hz - 500 Hz
-  length = WEBRTC_SPL_RSHIFT_W16(length, 1);
-  LogOfEnergy(hp_60, length, kOffsetVector[1], &power, &out_vector[1]);
-
-  // Remove DC and LFs
-  HighPassFilter(lp_60, length, inst->hp_filter_state, hp_120);
-
-  // Power in 80 Hz - 250 Hz
-  LogOfEnergy(hp_120, length, kOffsetVector[0], &power, &out_vector[0]);
+  in_ptr = lp_120;  // [0 - 500] Hz.
+  hp_out_ptr = hp_60;  // [250 - 500] Hz.
+  lp_out_ptr = lp_60;  // [0 - 250] Hz.
+  SplitFilter(in_ptr, length, &self->upper_state[frequency_band],
+              &self->lower_state[frequency_band], hp_out_ptr, lp_out_ptr);
+
+  // Energy in 250 Hz - 500 Hz.
+  length >>= 1;  // |data_length| / 16 <=> bandwidth = 250 Hz.
+  LogOfEnergy(hp_60, length, kOffsetVector[1], &power, &data_out[1]);
+
+  // Remove 0 Hz - 80 Hz, by high pass filtering the lower band.
+  HighPassFilter(lp_60, length, self->hp_filter_state, hp_120);
+
+  // Energy in 80 Hz - 250 Hz.
+  LogOfEnergy(hp_120, length, kOffsetVector[0], &power, &data_out[0]);

   return power;
 }
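To summarize the cascade implemented by the new function: each SplitFilter() call halves the number of samples and splits the band in two, each leaf band goes through LogOfEnergy() into data_out, and the lowest branch finishes with an 80 Hz high-pass instead of another split. The band edges and buffer reuse below are taken directly from the comments in the hunk:

data_in [0 - 4000 Hz]    --SplitFilter @ 2000 Hz--> hp_120 [2000 - 4000 Hz], lp_120 [0 - 2000 Hz]
hp_120  [2000 - 4000 Hz] --SplitFilter @ 3000 Hz--> data_out[5] = 3000 - 4000 Hz, data_out[4] = 2000 - 3000 Hz
lp_120  [0 - 2000 Hz]    --SplitFilter @ 1000 Hz--> data_out[3] = 1000 - 2000 Hz, lp_60 [0 - 1000 Hz]
lp_60   [0 - 1000 Hz]    --SplitFilter @ 500 Hz --> data_out[2] = 500 - 1000 Hz, lp_120 [0 - 500 Hz]
lp_120  [0 - 500 Hz]     --SplitFilter @ 250 Hz --> data_out[1] = 250 - 500 Hz, lp_60 [0 - 250 Hz]
lp_60   [0 - 250 Hz]     --HighPassFilter @ 80 Hz-> data_out[0] = 80 - 250 Hz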

File: vad_filterbank.h

@@ -9,8 +9,7 @@
  */

 /*
- * This header file includes the description of the internal VAD call
- * WebRtcVad_GaussianProbability.
+ * This file includes feature calculating functionality used in vad_core.c.
  */

 #ifndef WEBRTC_COMMON_AUDIO_VAD_VAD_FILTERBANK_H_
@@ -19,34 +18,27 @@
 #include "typedefs.h"
 #include "vad_core.h"

-// TODO(bjornv): Rename to CalcFeatures() or similar. Update at the same time
-// comments and parameter order.
-/****************************************************************************
- * WebRtcVad_get_features(...)
- *
- * This function is used to get the logarithm of the power of each of the
- * 6 frequency bands used by the VAD:
- *        80 Hz - 250 Hz
- *        250 Hz - 500 Hz
- *        500 Hz - 1000 Hz
- *        1000 Hz - 2000 Hz
- *        2000 Hz - 3000 Hz
- *        3000 Hz - 4000 Hz
- *
- * Input:
- *      - inst        : Pointer to VAD instance
- *      - in_vector   : Input speech signal
- *      - frame_size  : Frame size, in number of samples
- *
- * Output:
- *      - out_vector  : 10*log10(power in each freq. band), Q4
- *
- * Return: total power in the signal (NOTE! This value is not exact since it
- *         is only used in a comparison.
- */
-int16_t WebRtcVad_get_features(VadInstT* inst,
-                               const int16_t* in_vector,
-                               int frame_size,
-                               int16_t* out_vector);
+// Takes |data_length| samples of |data_in| and calculates the logarithm of the
+// power of each of the |NUM_CHANNELS| = 6 frequency bands used by the VAD:
+//        80 Hz - 250 Hz
+//        250 Hz - 500 Hz
+//        500 Hz - 1000 Hz
+//        1000 Hz - 2000 Hz
+//        2000 Hz - 3000 Hz
+//        3000 Hz - 4000 Hz
+//
+// The values are given in Q4 and written to |data_out|. Further, an approximate
+// overall power is returned. The return value is used in
+// WebRtcVad_GmmProbability() as a signal indicator, hence it is arbitrary above
+// the threshold MIN_ENERGY.
+//
+// - self        [i/o] : State information of the VAD.
+// - data_in     [i]   : Input audio data, for feature extraction.
+// - data_length [i]   : Audio data size, in number of samples.
+// - data_out    [o]   : 10 * log10(power in each frequency band), Q4.
+// - returns           : Total power of the signal (NOTE! This value is not
+//                       exact. It is only used in a comparison.)
+int16_t WebRtcVad_CalculateFeatures(VadInstT* self, const int16_t* data_in,
+                                    int data_length, int16_t* data_out);

 #endif  // WEBRTC_COMMON_AUDIO_VAD_VAD_FILTERBANK_H_
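A minimal caller sketch for the new interface, patterned on how the unit test below exercises it: WebRtcVad_InitCore(self, 0) is assumed to have been called first, exactly as in the test. The wrapper name and the 80-sample frame are illustrative; the headers are the ones that appear elsewhere in this change.

#include "typedefs.h"
#include "vad_core.h"        // VadInstT, WebRtcVad_InitCore().
#include "vad_defines.h"     // NUM_CHANNELS (assumed to be defined here).
#include "vad_filterbank.h"  // WebRtcVad_CalculateFeatures().

// Illustrative wrapper: run one 10 ms frame (80 samples at 8 kHz) through the
// filterbank. |features| receives 10 * log10(band power) in Q4 for the six
// bands; the return value is the approximate total power described above.
// |self| must already be initialized with WebRtcVad_InitCore(self, 0).
static int16_t FilterbankFeatures10ms(VadInstT* self, const int16_t speech[80],
                                      int16_t features[NUM_CHANNELS]) {
  return WebRtcVad_CalculateFeatures(self, speech, 80, features);
}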

File: vad_filterbank_unittest.cc

@@ -49,8 +49,8 @@ TEST_F(VadTest, vad_filterbank) {
   for (size_t j = 0; j < kFrameLengthsSize; ++j) {
     if (ValidRatesAndFrameLengths(8000, kFrameLengths[j])) {
       EXPECT_EQ(kReference[frame_length_index],
-                WebRtcVad_get_features(self, speech, kFrameLengths[j],
-                                       data_out));
+                WebRtcVad_CalculateFeatures(self, speech, kFrameLengths[j],
+                                            data_out));
       for (int k = 0; k < NUM_CHANNELS; ++k) {
         EXPECT_EQ(kReferencePowers[k + frame_length_index * NUM_CHANNELS],
                   data_out[k]);
@@ -65,8 +65,8 @@ TEST_F(VadTest, vad_filterbank) {
   ASSERT_EQ(0, WebRtcVad_InitCore(self, 0));
   for (size_t j = 0; j < kFrameLengthsSize; ++j) {
     if (ValidRatesAndFrameLengths(8000, kFrameLengths[j])) {
-      EXPECT_EQ(0, WebRtcVad_get_features(self, speech, kFrameLengths[j],
-                                          data_out));
+      EXPECT_EQ(0, WebRtcVad_CalculateFeatures(self, speech, kFrameLengths[j],
+                                               data_out));
       for (int k = 0; k < NUM_CHANNELS; ++k) {
         EXPECT_EQ(kOffsetVector[k], data_out[k]);
       }
@@ -81,8 +81,8 @@ TEST_F(VadTest, vad_filterbank) {
   for (size_t j = 0; j < kFrameLengthsSize; ++j) {
     if (ValidRatesAndFrameLengths(8000, kFrameLengths[j])) {
       ASSERT_EQ(0, WebRtcVad_InitCore(self, 0));
-      EXPECT_EQ(0, WebRtcVad_get_features(self, speech, kFrameLengths[j],
-                                          data_out));
+      EXPECT_EQ(0, WebRtcVad_CalculateFeatures(self, speech, kFrameLengths[j],
+                                               data_out));
       for (int k = 0; k < NUM_CHANNELS; ++k) {
         EXPECT_EQ(kOffsetVector[k], data_out[k]);
       }