diff --git a/modules/audio_processing/aec/main/source/aec.gyp b/modules/audio_processing/aec/main/source/aec.gyp index 769135999..96f9bd75b 100644 --- a/modules/audio_processing/aec/main/source/aec.gyp +++ b/modules/audio_processing/aec/main/source/aec.gyp @@ -30,6 +30,7 @@ '../interface/echo_cancellation.h', 'echo_cancellation.c', 'aec_core.c', + 'aec_core_sse2.c', 'aec_core.h', 'resampler.c', 'resampler.h', diff --git a/modules/audio_processing/aec/main/source/aec_core.c b/modules/audio_processing/aec/main/source/aec_core.c index c13654ca7..daf684755 100644 --- a/modules/audio_processing/aec/main/source/aec_core.c +++ b/modules/audio_processing/aec/main/source/aec_core.c @@ -18,6 +18,7 @@ #include "aec_core.h" #include "ring_buffer.h" +#include "system_wrappers/interface/cpu_features_wrapper.h" #define IP_LEN PART_LEN // this must be at least ceil(2 + sqrt(PART_LEN)) #define W_LEN PART_LEN @@ -215,6 +216,49 @@ int WebRtcAec_FreeAec(aec_t *aec) return 0; } +static void FilterFar(aec_t *aec, float yf[2][PART_LEN1]) +{ + for (int i = 0; i < NR_PART; i++) { + int xPos = (i + aec->xfBufBlockPos) * PART_LEN1; + // Check for wrap + if (i + aec->xfBufBlockPos >= NR_PART) { + xPos -= NR_PART*(PART_LEN1); + } + + int pos = i * PART_LEN1; + for (int j = 0; j < PART_LEN1; j++) { + yf[0][j] += MulRe(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j], + aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]); + yf[1][j] += MulIm(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j], + aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]); + } + } +} + +static void ScaleErrorSignal(aec_t *aec, float ef[2][PART_LEN1]) +{ + for (int i = 0; i < (PART_LEN1); i++) { + ef[0][i] /= (aec->xPow[i] + 1e-10f); + ef[1][i] /= (aec->xPow[i] + 1e-10f); + float absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]); + + if (absEf > aec->errThresh) { + absEf = aec->errThresh / (absEf + 1e-10f); + ef[0][i] *= absEf; + ef[1][i] *= absEf; + } + + // Stepsize factor + ef[0][i] *= aec->mu; + ef[1][i] *= aec->mu; + } +} + +WebRtcAec_FilterFar_t WebRtcAec_FilterFar; +WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal; + +extern void WebRtcAec_InitAec_SSE2(void); + int WebRtcAec_InitAec(aec_t *aec, int sampFreq) { int i; @@ -291,7 +335,8 @@ int WebRtcAec_InitAec(aec_t *aec, int sampFreq) // Holds the last block written to aec->xfBufBlockPos = 0; - + // TODO: Investigate need for these initializations. Deleting them doesn't + // change the output at all and yields 0.4% overall speedup. memset(aec->xfBuf, 0, sizeof(complex_t) * NR_PART * PART_LEN1); memset(aec->wfBuf, 0, sizeof(complex_t) * NR_PART * PART_LEN1); memset(aec->sde, 0, sizeof(complex_t) * PART_LEN1); @@ -336,6 +381,15 @@ int WebRtcAec_InitAec(aec_t *aec, int sampFreq) aec->metricsMode = 0; WebRtcAec_InitMetrics(aec); + // Assembly optimization + WebRtcAec_FilterFar = FilterFar; + WebRtcAec_ScaleErrorSignal = ScaleErrorSignal; + if (WebRtc_GetCPUInfo(kSSE2)) { +#if defined(__SSE2__) + WebRtcAec_InitAec_SSE2(); +#endif + } + return 0; } @@ -434,7 +488,8 @@ static void ProcessBlock(aec_t *aec, const short *farend, float absEf; float fft[PART_LEN2]; - complex_t xf[PART_LEN1], df[PART_LEN1], yf[PART_LEN1], ef[PART_LEN1]; + float xf[2][PART_LEN1], yf[2][PART_LEN1], ef[2][PART_LEN1]; + complex_t df[PART_LEN1]; int ip[IP_LEN]; float wfft[W_LEN]; @@ -479,14 +534,14 @@ static void ProcessBlock(aec_t *aec, const short *farend, rdft(PART_LEN2, 1, fft, ip, wfft); // Far fft - xf[0][1] = 0; - xf[PART_LEN][1] = 0; + xf[1][0] = 0; + xf[1][PART_LEN] = 0; xf[0][0] = fft[0]; - xf[PART_LEN][0] = fft[1]; + xf[0][PART_LEN] = fft[1]; for (i = 1; i < PART_LEN; i++) { - xf[i][0] = fft[2 * i]; - xf[i][1] = fft[2 * i + 1]; + xf[0][i] = fft[2 * i]; + xf[1][i] = fft[2 * i + 1]; } // Near fft @@ -505,7 +560,7 @@ static void ProcessBlock(aec_t *aec, const short *farend, // Power smoothing for (i = 0; i < PART_LEN1; i++) { aec->xPow[i] = gPow[0] * aec->xPow[i] + gPow[1] * NR_PART * - (xf[i][0] * xf[i][0] + xf[i][1] * xf[i][1]); + (xf[0][i] * xf[0][i] + xf[1][i] * xf[1][i]); aec->dPow[i] = gPow[0] * aec->dPow[i] + gPow[1] * (df[i][0] * df[i][0] + df[i][1] * df[i][1]); } @@ -550,34 +605,22 @@ static void ProcessBlock(aec_t *aec, const short *farend, } // Buffer xf - memcpy(aec->xfBuf + aec->xfBufBlockPos * PART_LEN1, xf, sizeof(complex_t) * - PART_LEN1); + memcpy(aec->xfBuf[0] + aec->xfBufBlockPos * PART_LEN1, xf[0], + sizeof(float) * PART_LEN1); + memcpy(aec->xfBuf[1] + aec->xfBufBlockPos * PART_LEN1, xf[1], + sizeof(float) * PART_LEN1); - memset(yf, 0, sizeof(complex_t) * (PART_LEN1)); + memset(yf[0], 0, sizeof(float) * (PART_LEN1 * 2)); // Filter far - for (i = 0; i < NR_PART; i++) { - xPos = (i + aec->xfBufBlockPos) * PART_LEN1; - // Check for wrap - if (i + aec->xfBufBlockPos >= NR_PART) { - xPos -= NR_PART*(PART_LEN1); - } - - pos = i * PART_LEN1; - for (j = 0; j < PART_LEN1; j++) { - yf[j][0] += MulRe(aec->xfBuf[xPos + j][0], aec->xfBuf[xPos + j][1], - aec->wfBuf[pos + j][0], aec->wfBuf[pos + j][1]); - yf[j][1] += MulIm(aec->xfBuf[xPos + j][0], aec->xfBuf[xPos + j][1], - aec->wfBuf[pos + j][0], aec->wfBuf[pos + j][1]); - } - } + WebRtcAec_FilterFar(aec, yf); // Inverse fft to obtain echo estimate and error. fft[0] = yf[0][0]; - fft[1] = yf[PART_LEN][0]; + fft[1] = yf[0][PART_LEN]; for (i = 1; i < PART_LEN; i++) { - fft[2 * i] = yf[i][0]; - fft[2 * i + 1] = yf[i][1]; + fft[2 * i] = yf[0][i]; + fft[2 * i + 1] = yf[1][i]; } rdft(PART_LEN2, -1, fft, ip, wfft); @@ -596,32 +639,17 @@ static void ProcessBlock(aec_t *aec, const short *farend, memcpy(fft + PART_LEN, e, sizeof(float) * PART_LEN); rdft(PART_LEN2, 1, fft, ip, wfft); - ef[0][1] = 0; - ef[PART_LEN][1] = 0; + ef[1][0] = 0; + ef[1][PART_LEN] = 0; ef[0][0] = fft[0]; - ef[PART_LEN][0] = fft[1]; + ef[0][PART_LEN] = fft[1]; for (i = 1; i < PART_LEN; i++) { - ef[i][0] = fft[2 * i]; - ef[i][1] = fft[2 * i + 1]; + ef[0][i] = fft[2 * i]; + ef[1][i] = fft[2 * i + 1]; } // Scale error signal inversely with far power. - for (i = 0; i < (PART_LEN1); i++) { - ef[i][0] /= (aec->xPow[i] + 1e-10f); - ef[i][1] /= (aec->xPow[i] + 1e-10f); - absEf = sqrtf(ef[i][0] * ef[i][0] + ef[i][1] * ef[i][1]); - - if (absEf > aec->errThresh) { - absEf = aec->errThresh / (absEf + 1e-10f); - ef[i][0] *= absEf; - ef[i][1] *= absEf; - } - - // Stepsize factor - ef[i][0] *= aec->mu; - ef[i][1] *= aec->mu; - } - + WebRtcAec_ScaleErrorSignal(aec, ef); #ifdef G167 if (aec->adaptToggle) { #endif @@ -643,16 +671,20 @@ static void ProcessBlock(aec_t *aec, const short *farend, -aec->xfBuf[xPos + j][1], ef[j][0], ef[j][1]); } #else - fft[0] = MulRe(aec->xfBuf[xPos][0], -aec->xfBuf[xPos][1], ef[0][0], ef[0][1]); - fft[1] = MulRe(aec->xfBuf[xPos + PART_LEN][0], -aec->xfBuf[xPos + PART_LEN][1], - ef[PART_LEN][0], ef[PART_LEN][1]); + fft[0] = MulRe(aec->xfBuf[0][xPos], -aec->xfBuf[1][xPos], + ef[0][0], ef[1][0]); + fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN], + -aec->xfBuf[1][xPos + PART_LEN], + ef[0][PART_LEN], ef[1][PART_LEN]); for (j = 1; j < PART_LEN; j++) { - fft[2 * j] = MulRe(aec->xfBuf[xPos + j][0], -aec->xfBuf[xPos + j][1], - ef[j][0], ef[j][1]); - fft[2 * j + 1] = MulIm(aec->xfBuf[xPos + j][0], -aec->xfBuf[xPos + j][1], - ef[j][0], ef[j][1]); + fft[2 * j] = MulRe(aec->xfBuf[0][xPos + j], + -aec->xfBuf[1][xPos + j], + ef[0][j], ef[1][j]); + fft[2 * j + 1] = MulIm(aec->xfBuf[0][xPos + j], + -aec->xfBuf[1][xPos + j], + ef[0][j], ef[1][j]); } rdft(PART_LEN2, -1, fft, ip, wfft); memset(fft + PART_LEN, 0, sizeof(float)*PART_LEN); @@ -663,12 +695,12 @@ static void ProcessBlock(aec_t *aec, const short *farend, } rdft(PART_LEN2, 1, fft, ip, wfft); - aec->wfBuf[pos][0] += fft[0]; - aec->wfBuf[pos + PART_LEN][0] += fft[1]; + aec->wfBuf[0][pos] += fft[0]; + aec->wfBuf[0][pos + PART_LEN] += fft[1]; for (j = 1; j < PART_LEN; j++) { - aec->wfBuf[pos + j][0] += fft[2 * j]; - aec->wfBuf[pos + j][1] += fft[2 * j + 1]; + aec->wfBuf[0][pos + j] += fft[2 * j]; + aec->wfBuf[1][pos + j] += fft[2 * j + 1]; } #endif // UNCONSTR } @@ -759,8 +791,8 @@ static void NonLinearProcessing(aec_t *aec, int *ip, float *wfft, short *output, pos = i * PART_LEN1; wfEn = 0; for (j = 0; j < PART_LEN1; j++) { - wfEn += aec->wfBuf[pos + j][0] * aec->wfBuf[pos + j][0] + - aec->wfBuf[pos + j][1] * aec->wfBuf[pos + j][1]; + wfEn += aec->wfBuf[0][pos + j] * aec->wfBuf[0][pos + j] + + aec->wfBuf[1][pos + j] * aec->wfBuf[1][pos + j]; } if (wfEn > wfEnMax) { diff --git a/modules/audio_processing/aec/main/source/aec_core.h b/modules/audio_processing/aec/main/source/aec_core.h index 88a3bcb66..35d2f6b70 100644 --- a/modules/audio_processing/aec/main/source/aec_core.h +++ b/modules/audio_processing/aec/main/source/aec_core.h @@ -37,6 +37,13 @@ #define BLOCKL_MAX FRAME_LEN typedef float complex_t[2]; +// For performance reasons, some arrays of complex numbers are replaced by twice +// as long arrays of float, all the real parts followed by all the imaginary +// ones (complex_t[SIZE] -> float[2][SIZE]). This allows SIMD optimizations and +// is better than two arrays (one for the real parts and one for the imaginary +// parts) as this other way would require two pointers instead of one and cause +// extra register spilling. This also allows the offsets to be calculated at +// compile time. // Metrics enum {offsetLevel = -100}; @@ -95,8 +102,8 @@ typedef struct { fftw_complex wfBuf[NR_PART * PART_LEN1]; fftw_complex sde[PART_LEN1]; #else - complex_t xfBuf[NR_PART * PART_LEN1]; // farend fft buffer - complex_t wfBuf[NR_PART * PART_LEN1]; // filter fft + float xfBuf[2][NR_PART * PART_LEN1]; // farend fft buffer + float wfBuf[2][NR_PART * PART_LEN1]; // filter fft complex_t sde[PART_LEN1]; // cross-psd of nearend and error complex_t sxd[PART_LEN1]; // cross-psd of farend and nearend complex_t xfwBuf[NR_PART * PART_LEN1]; // farend windowed fft buffer @@ -159,6 +166,11 @@ typedef struct { #endif } aec_t; +typedef void (*WebRtcAec_FilterFar_t)(aec_t *aec, float yf[2][PART_LEN1]); +extern WebRtcAec_FilterFar_t WebRtcAec_FilterFar; +typedef void (*WebRtcAec_ScaleErrorSignal_t)(aec_t *aec, float ef[2][PART_LEN1]); +extern WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal; + int WebRtcAec_CreateAec(aec_t **aec); int WebRtcAec_FreeAec(aec_t *aec); int WebRtcAec_InitAec(aec_t *aec, int sampFreq); diff --git a/modules/audio_processing/aec/main/source/aec_core_sse2.c b/modules/audio_processing/aec/main/source/aec_core_sse2.c new file mode 100644 index 000000000..e431e049c --- /dev/null +++ b/modules/audio_processing/aec/main/source/aec_core_sse2.c @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2011 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +/* + * The core AEC algorithm, SSE2 version of speed-critical functions. + */ + +#if defined(__SSE2__) +#include +#include + +#include "aec_core.h" + +__inline static float MulRe(float aRe, float aIm, float bRe, float bIm) +{ + return aRe * bRe - aIm * bIm; +} + +__inline static float MulIm(float aRe, float aIm, float bRe, float bIm) +{ + return aRe * bIm + aIm * bRe; +} + +static void FilterFarSSE2(aec_t *aec, float yf[2][PART_LEN1]) +{ + for (int i = 0; i < NR_PART; i++) { + int xPos = (i + aec->xfBufBlockPos) * PART_LEN1; + // Check for wrap + if (i + aec->xfBufBlockPos >= NR_PART) { + xPos -= NR_PART*(PART_LEN1); + } + + int pos = i * PART_LEN1; + // vectorized code (four at once) + int j; + for (j = 0; j + 3 < PART_LEN1; j += 4) { + const __m128 xfBuf_re = _mm_loadu_ps(&aec->xfBuf[0][xPos + j]); + const __m128 xfBuf_im = _mm_loadu_ps(&aec->xfBuf[1][xPos + j]); + const __m128 wfBuf_re = _mm_loadu_ps(&aec->wfBuf[0][pos + j]); + const __m128 wfBuf_im = _mm_loadu_ps(&aec->wfBuf[1][pos + j]); + const __m128 yf_re = _mm_loadu_ps(&yf[0][j]); + const __m128 yf_im = _mm_loadu_ps(&yf[1][j]); + const __m128 a = _mm_mul_ps(xfBuf_re, wfBuf_re); + const __m128 b = _mm_mul_ps(xfBuf_im, wfBuf_im); + const __m128 c = _mm_mul_ps(xfBuf_re, wfBuf_im); + const __m128 d = _mm_mul_ps(xfBuf_im, wfBuf_re); + const __m128 e = _mm_sub_ps(a, b); + const __m128 f = _mm_add_ps(c, d); + const __m128 g = _mm_add_ps(yf_re, e); + const __m128 h = _mm_add_ps(yf_im, f); + _mm_storeu_ps(&yf[0][j], g); + _mm_storeu_ps(&yf[1][j], h); + } + // scalar code for the remaining items. + for (; j < PART_LEN1; j++) { + yf[0][j] += MulRe(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j], + aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]); + yf[1][j] += MulIm(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j], + aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]); + } + } +} + +static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1]) +{ + const __m128 k1e_10f = _mm_set1_ps(1e-10f); + const __m128 kThresh = _mm_set1_ps(aec->errThresh); + const __m128 kMu = _mm_set1_ps(aec->mu); + + int i; + // vectorized code (four at once) + for (i = 0; i + 3 < PART_LEN1; i += 4) { + const __m128 xPow = _mm_loadu_ps(&aec->xPow[i]); + __m128 ef_re = _mm_loadu_ps(&ef[0][i]); + __m128 ef_im = _mm_loadu_ps(&ef[1][i]); + + const __m128 xPowPlus = _mm_add_ps(xPow, k1e_10f); + ef_re = _mm_div_ps(ef_re, xPowPlus); + ef_im = _mm_div_ps(ef_im, xPowPlus); + const __m128 ef_re2 = _mm_mul_ps(ef_re, ef_re); + const __m128 ef_im2 = _mm_mul_ps(ef_im, ef_im); + const __m128 ef_sum2 = _mm_add_ps(ef_re2, ef_im2); + const __m128 absEf = _mm_sqrt_ps(ef_sum2); + const __m128 bigger = _mm_cmpgt_ps(absEf, kThresh); + __m128 absEfPlus = _mm_add_ps(absEf, k1e_10f); + const __m128 absEfInv = _mm_div_ps(kThresh, absEfPlus); + __m128 ef_re_if = _mm_mul_ps(ef_re, absEfInv); + __m128 ef_im_if = _mm_mul_ps(ef_im, absEfInv); + ef_re_if = _mm_and_ps(bigger, ef_re_if); + ef_im_if = _mm_and_ps(bigger, ef_im_if); + ef_re = _mm_andnot_ps(bigger, ef_re); + ef_im = _mm_andnot_ps(bigger, ef_im); + ef_re = _mm_or_ps(ef_re, ef_re_if); + ef_im = _mm_or_ps(ef_im, ef_im_if); + ef_re = _mm_mul_ps(ef_re, kMu); + ef_im = _mm_mul_ps(ef_im, kMu); + + _mm_storeu_ps(&ef[0][i], ef_re); + _mm_storeu_ps(&ef[1][i], ef_im); + } + // scalar code for the remaining items. + for (; i < (PART_LEN1); i++) { + ef[0][i] /= (aec->xPow[i] + 1e-10f); + ef[1][i] /= (aec->xPow[i] + 1e-10f); + float absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]); + + if (absEf > aec->errThresh) { + absEf = aec->errThresh / (absEf + 1e-10f); + ef[0][i] *= absEf; + ef[1][i] *= absEf; + } + + // Stepsize factor + ef[0][i] *= aec->mu; + ef[1][i] *= aec->mu; + } +} + +void WebRtcAec_InitAec_SSE2(void) { + WebRtcAec_FilterFar = FilterFarSSE2; + WebRtcAec_ScaleErrorSignal = ScaleErrorSignalSSE2; +} + +#endif //__SSE2__