Partial vectorization of "ProcessBlock":

* new file for SSE2 code, code selection through function pointers.
* structure change for array of complex numbers.
* 3.8% AEC overall speedup for straight C path.
* 8.8% AEC overall speedup for SSE2 path.
Review URL: http://webrtc-codereview.appspot.com/34002

git-svn-id: http://webrtc.googlecode.com/svn/trunk@36 4adac7df-926f-26a2-2b94-8c16560cd09d
This commit is contained in:
cduvivier@google.com 2011-06-02 01:38:10 +00:00
parent 43a4bd594d
commit 936b36dbf6
4 changed files with 244 additions and 66 deletions

View File

@ -30,6 +30,7 @@
'../interface/echo_cancellation.h',
'echo_cancellation.c',
'aec_core.c',
'aec_core_sse2.c',
'aec_core.h',
'resampler.c',
'resampler.h',

View File

@ -18,6 +18,7 @@
#include "aec_core.h"
#include "ring_buffer.h"
#include "system_wrappers/interface/cpu_features_wrapper.h"
#define IP_LEN PART_LEN // this must be at least ceil(2 + sqrt(PART_LEN))
#define W_LEN PART_LEN
@ -215,6 +216,52 @@ int WebRtcAec_FreeAec(aec_t *aec)
return 0;
}
static void FilterFar(aec_t *aec, float yf[2][PART_LEN1])
{
int i, j, pos;
for (i = 0; i < NR_PART; i++) {
int xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART*(PART_LEN1);
}
pos = i * PART_LEN1;
for (j = 0; j < PART_LEN1; j++) {
yf[0][j] += MulRe(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j],
aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]);
yf[1][j] += MulIm(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j],
aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]);
}
}
}
static void ScaleErrorSignal(aec_t *aec, float ef[2][PART_LEN1])
{
int i;
float absEf;
for (i = 0; i < (PART_LEN1); i++) {
ef[0][i] /= (aec->xPow[i] + 1e-10f);
ef[1][i] /= (aec->xPow[i] + 1e-10f);
absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
if (absEf > aec->errThresh) {
absEf = aec->errThresh / (absEf + 1e-10f);
ef[0][i] *= absEf;
ef[1][i] *= absEf;
}
// Stepsize factor
ef[0][i] *= aec->mu;
ef[1][i] *= aec->mu;
}
}
WebRtcAec_FilterFar_t WebRtcAec_FilterFar;
WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal;
extern void WebRtcAec_InitAec_SSE2(void);
int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
{
int i;
@ -291,7 +338,8 @@ int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
// Holds the last block written to
aec->xfBufBlockPos = 0;
// TODO: Investigate need for these initializations. Deleting them doesn't
// change the output at all and yields 0.4% overall speedup.
memset(aec->xfBuf, 0, sizeof(complex_t) * NR_PART * PART_LEN1);
memset(aec->wfBuf, 0, sizeof(complex_t) * NR_PART * PART_LEN1);
memset(aec->sde, 0, sizeof(complex_t) * PART_LEN1);
@ -336,6 +384,15 @@ int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
aec->metricsMode = 0;
WebRtcAec_InitMetrics(aec);
// Assembly optimization
WebRtcAec_FilterFar = FilterFar;
WebRtcAec_ScaleErrorSignal = ScaleErrorSignal;
if (WebRtc_GetCPUInfo(kSSE2)) {
#if defined(__SSE2__)
WebRtcAec_InitAec_SSE2();
#endif
}
return 0;
}
@ -431,10 +488,10 @@ static void ProcessBlock(aec_t *aec, const short *farend,
short eInt16[PART_LEN];
float scale;
int xPos;
float absEf;
float fft[PART_LEN2];
complex_t xf[PART_LEN1], df[PART_LEN1], yf[PART_LEN1], ef[PART_LEN1];
float xf[2][PART_LEN1], yf[2][PART_LEN1], ef[2][PART_LEN1];
complex_t df[PART_LEN1];
int ip[IP_LEN];
float wfft[W_LEN];
@ -479,14 +536,14 @@ static void ProcessBlock(aec_t *aec, const short *farend,
rdft(PART_LEN2, 1, fft, ip, wfft);
// Far fft
xf[0][1] = 0;
xf[PART_LEN][1] = 0;
xf[1][0] = 0;
xf[1][PART_LEN] = 0;
xf[0][0] = fft[0];
xf[PART_LEN][0] = fft[1];
xf[0][PART_LEN] = fft[1];
for (i = 1; i < PART_LEN; i++) {
xf[i][0] = fft[2 * i];
xf[i][1] = fft[2 * i + 1];
xf[0][i] = fft[2 * i];
xf[1][i] = fft[2 * i + 1];
}
// Near fft
@ -505,7 +562,7 @@ static void ProcessBlock(aec_t *aec, const short *farend,
// Power smoothing
for (i = 0; i < PART_LEN1; i++) {
aec->xPow[i] = gPow[0] * aec->xPow[i] + gPow[1] * NR_PART *
(xf[i][0] * xf[i][0] + xf[i][1] * xf[i][1]);
(xf[0][i] * xf[0][i] + xf[1][i] * xf[1][i]);
aec->dPow[i] = gPow[0] * aec->dPow[i] + gPow[1] *
(df[i][0] * df[i][0] + df[i][1] * df[i][1]);
}
@ -550,34 +607,22 @@ static void ProcessBlock(aec_t *aec, const short *farend,
}
// Buffer xf
memcpy(aec->xfBuf + aec->xfBufBlockPos * PART_LEN1, xf, sizeof(complex_t) *
PART_LEN1);
memcpy(aec->xfBuf[0] + aec->xfBufBlockPos * PART_LEN1, xf[0],
sizeof(float) * PART_LEN1);
memcpy(aec->xfBuf[1] + aec->xfBufBlockPos * PART_LEN1, xf[1],
sizeof(float) * PART_LEN1);
memset(yf, 0, sizeof(complex_t) * (PART_LEN1));
memset(yf[0], 0, sizeof(float) * (PART_LEN1 * 2));
// Filter far
for (i = 0; i < NR_PART; i++) {
xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART*(PART_LEN1);
}
pos = i * PART_LEN1;
for (j = 0; j < PART_LEN1; j++) {
yf[j][0] += MulRe(aec->xfBuf[xPos + j][0], aec->xfBuf[xPos + j][1],
aec->wfBuf[pos + j][0], aec->wfBuf[pos + j][1]);
yf[j][1] += MulIm(aec->xfBuf[xPos + j][0], aec->xfBuf[xPos + j][1],
aec->wfBuf[pos + j][0], aec->wfBuf[pos + j][1]);
}
}
WebRtcAec_FilterFar(aec, yf);
// Inverse fft to obtain echo estimate and error.
fft[0] = yf[0][0];
fft[1] = yf[PART_LEN][0];
fft[1] = yf[0][PART_LEN];
for (i = 1; i < PART_LEN; i++) {
fft[2 * i] = yf[i][0];
fft[2 * i + 1] = yf[i][1];
fft[2 * i] = yf[0][i];
fft[2 * i + 1] = yf[1][i];
}
rdft(PART_LEN2, -1, fft, ip, wfft);
@ -596,32 +641,17 @@ static void ProcessBlock(aec_t *aec, const short *farend,
memcpy(fft + PART_LEN, e, sizeof(float) * PART_LEN);
rdft(PART_LEN2, 1, fft, ip, wfft);
ef[0][1] = 0;
ef[PART_LEN][1] = 0;
ef[1][0] = 0;
ef[1][PART_LEN] = 0;
ef[0][0] = fft[0];
ef[PART_LEN][0] = fft[1];
ef[0][PART_LEN] = fft[1];
for (i = 1; i < PART_LEN; i++) {
ef[i][0] = fft[2 * i];
ef[i][1] = fft[2 * i + 1];
ef[0][i] = fft[2 * i];
ef[1][i] = fft[2 * i + 1];
}
// Scale error signal inversely with far power.
for (i = 0; i < (PART_LEN1); i++) {
ef[i][0] /= (aec->xPow[i] + 1e-10f);
ef[i][1] /= (aec->xPow[i] + 1e-10f);
absEf = sqrtf(ef[i][0] * ef[i][0] + ef[i][1] * ef[i][1]);
if (absEf > aec->errThresh) {
absEf = aec->errThresh / (absEf + 1e-10f);
ef[i][0] *= absEf;
ef[i][1] *= absEf;
}
// Stepsize factor
ef[i][0] *= aec->mu;
ef[i][1] *= aec->mu;
}
WebRtcAec_ScaleErrorSignal(aec, ef);
#ifdef G167
if (aec->adaptToggle) {
#endif
@ -643,16 +673,20 @@ static void ProcessBlock(aec_t *aec, const short *farend,
-aec->xfBuf[xPos + j][1], ef[j][0], ef[j][1]);
}
#else
fft[0] = MulRe(aec->xfBuf[xPos][0], -aec->xfBuf[xPos][1], ef[0][0], ef[0][1]);
fft[1] = MulRe(aec->xfBuf[xPos + PART_LEN][0], -aec->xfBuf[xPos + PART_LEN][1],
ef[PART_LEN][0], ef[PART_LEN][1]);
fft[0] = MulRe(aec->xfBuf[0][xPos], -aec->xfBuf[1][xPos],
ef[0][0], ef[1][0]);
fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
-aec->xfBuf[1][xPos + PART_LEN],
ef[0][PART_LEN], ef[1][PART_LEN]);
for (j = 1; j < PART_LEN; j++) {
fft[2 * j] = MulRe(aec->xfBuf[xPos + j][0], -aec->xfBuf[xPos + j][1],
ef[j][0], ef[j][1]);
fft[2 * j + 1] = MulIm(aec->xfBuf[xPos + j][0], -aec->xfBuf[xPos + j][1],
ef[j][0], ef[j][1]);
fft[2 * j] = MulRe(aec->xfBuf[0][xPos + j],
-aec->xfBuf[1][xPos + j],
ef[0][j], ef[1][j]);
fft[2 * j + 1] = MulIm(aec->xfBuf[0][xPos + j],
-aec->xfBuf[1][xPos + j],
ef[0][j], ef[1][j]);
}
rdft(PART_LEN2, -1, fft, ip, wfft);
memset(fft + PART_LEN, 0, sizeof(float)*PART_LEN);
@ -663,12 +697,12 @@ static void ProcessBlock(aec_t *aec, const short *farend,
}
rdft(PART_LEN2, 1, fft, ip, wfft);
aec->wfBuf[pos][0] += fft[0];
aec->wfBuf[pos + PART_LEN][0] += fft[1];
aec->wfBuf[0][pos] += fft[0];
aec->wfBuf[0][pos + PART_LEN] += fft[1];
for (j = 1; j < PART_LEN; j++) {
aec->wfBuf[pos + j][0] += fft[2 * j];
aec->wfBuf[pos + j][1] += fft[2 * j + 1];
aec->wfBuf[0][pos + j] += fft[2 * j];
aec->wfBuf[1][pos + j] += fft[2 * j + 1];
}
#endif // UNCONSTR
}
@ -759,8 +793,8 @@ static void NonLinearProcessing(aec_t *aec, int *ip, float *wfft, short *output,
pos = i * PART_LEN1;
wfEn = 0;
for (j = 0; j < PART_LEN1; j++) {
wfEn += aec->wfBuf[pos + j][0] * aec->wfBuf[pos + j][0] +
aec->wfBuf[pos + j][1] * aec->wfBuf[pos + j][1];
wfEn += aec->wfBuf[0][pos + j] * aec->wfBuf[0][pos + j] +
aec->wfBuf[1][pos + j] * aec->wfBuf[1][pos + j];
}
if (wfEn > wfEnMax) {

View File

@ -37,6 +37,13 @@
#define BLOCKL_MAX FRAME_LEN
typedef float complex_t[2];
// For performance reasons, some arrays of complex numbers are replaced by twice
// as long arrays of float, all the real parts followed by all the imaginary
// ones (complex_t[SIZE] -> float[2][SIZE]). This allows SIMD optimizations and
// is better than two arrays (one for the real parts and one for the imaginary
// parts) as this other way would require two pointers instead of one and cause
// extra register spilling. This also allows the offsets to be calculated at
// compile time.
// Metrics
enum {offsetLevel = -100};
@ -95,8 +102,8 @@ typedef struct {
fftw_complex wfBuf[NR_PART * PART_LEN1];
fftw_complex sde[PART_LEN1];
#else
complex_t xfBuf[NR_PART * PART_LEN1]; // farend fft buffer
complex_t wfBuf[NR_PART * PART_LEN1]; // filter fft
float xfBuf[2][NR_PART * PART_LEN1]; // farend fft buffer
float wfBuf[2][NR_PART * PART_LEN1]; // filter fft
complex_t sde[PART_LEN1]; // cross-psd of nearend and error
complex_t sxd[PART_LEN1]; // cross-psd of farend and nearend
complex_t xfwBuf[NR_PART * PART_LEN1]; // farend windowed fft buffer
@ -159,6 +166,11 @@ typedef struct {
#endif
} aec_t;
typedef void (*WebRtcAec_FilterFar_t)(aec_t *aec, float yf[2][PART_LEN1]);
extern WebRtcAec_FilterFar_t WebRtcAec_FilterFar;
typedef void (*WebRtcAec_ScaleErrorSignal_t)(aec_t *aec, float ef[2][PART_LEN1]);
extern WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal;
int WebRtcAec_CreateAec(aec_t **aec);
int WebRtcAec_FreeAec(aec_t *aec);
int WebRtcAec_InitAec(aec_t *aec, int sampFreq);

View File

@ -0,0 +1,131 @@
/*
* Copyright (c) 2011 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
/*
* The core AEC algorithm, SSE2 version of speed-critical functions.
*/
#if defined(__SSE2__)
#include <emmintrin.h>
#include <math.h>
#include "aec_core.h"
__inline static float MulRe(float aRe, float aIm, float bRe, float bIm)
{
return aRe * bRe - aIm * bIm;
}
__inline static float MulIm(float aRe, float aIm, float bRe, float bIm)
{
return aRe * bIm + aIm * bRe;
}
static void FilterFarSSE2(aec_t *aec, float yf[2][PART_LEN1])
{
int i, j, pos;
for (i = 0; i < NR_PART; i++) {
int xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART*(PART_LEN1);
}
pos = i * PART_LEN1;
// vectorized code (four at once)
for (j = 0; j + 3 < PART_LEN1; j += 4) {
const __m128 xfBuf_re = _mm_loadu_ps(&aec->xfBuf[0][xPos + j]);
const __m128 xfBuf_im = _mm_loadu_ps(&aec->xfBuf[1][xPos + j]);
const __m128 wfBuf_re = _mm_loadu_ps(&aec->wfBuf[0][pos + j]);
const __m128 wfBuf_im = _mm_loadu_ps(&aec->wfBuf[1][pos + j]);
const __m128 yf_re = _mm_loadu_ps(&yf[0][j]);
const __m128 yf_im = _mm_loadu_ps(&yf[1][j]);
const __m128 a = _mm_mul_ps(xfBuf_re, wfBuf_re);
const __m128 b = _mm_mul_ps(xfBuf_im, wfBuf_im);
const __m128 c = _mm_mul_ps(xfBuf_re, wfBuf_im);
const __m128 d = _mm_mul_ps(xfBuf_im, wfBuf_re);
const __m128 e = _mm_sub_ps(a, b);
const __m128 f = _mm_add_ps(c, d);
const __m128 g = _mm_add_ps(yf_re, e);
const __m128 h = _mm_add_ps(yf_im, f);
_mm_storeu_ps(&yf[0][j], g);
_mm_storeu_ps(&yf[1][j], h);
}
// scalar code for the remaining items.
for (; j < PART_LEN1; j++) {
yf[0][j] += MulRe(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j],
aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]);
yf[1][j] += MulIm(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j],
aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]);
}
}
}
static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
{
const __m128 k1e_10f = _mm_set1_ps(1e-10f);
const __m128 kThresh = _mm_set1_ps(aec->errThresh);
const __m128 kMu = _mm_set1_ps(aec->mu);
int i;
// vectorized code (four at once)
for (i = 0; i + 3 < PART_LEN1; i += 4) {
const __m128 xPow = _mm_loadu_ps(&aec->xPow[i]);
__m128 ef_re = _mm_loadu_ps(&ef[0][i]);
__m128 ef_im = _mm_loadu_ps(&ef[1][i]);
const __m128 xPowPlus = _mm_add_ps(xPow, k1e_10f);
ef_re = _mm_div_ps(ef_re, xPowPlus);
ef_im = _mm_div_ps(ef_im, xPowPlus);
const __m128 ef_re2 = _mm_mul_ps(ef_re, ef_re);
const __m128 ef_im2 = _mm_mul_ps(ef_im, ef_im);
const __m128 ef_sum2 = _mm_add_ps(ef_re2, ef_im2);
const __m128 absEf = _mm_sqrt_ps(ef_sum2);
const __m128 bigger = _mm_cmpgt_ps(absEf, kThresh);
__m128 absEfPlus = _mm_add_ps(absEf, k1e_10f);
const __m128 absEfInv = _mm_div_ps(kThresh, absEfPlus);
__m128 ef_re_if = _mm_mul_ps(ef_re, absEfInv);
__m128 ef_im_if = _mm_mul_ps(ef_im, absEfInv);
ef_re_if = _mm_and_ps(bigger, ef_re_if);
ef_im_if = _mm_and_ps(bigger, ef_im_if);
ef_re = _mm_andnot_ps(bigger, ef_re);
ef_im = _mm_andnot_ps(bigger, ef_im);
ef_re = _mm_or_ps(ef_re, ef_re_if);
ef_im = _mm_or_ps(ef_im, ef_im_if);
ef_re = _mm_mul_ps(ef_re, kMu);
ef_im = _mm_mul_ps(ef_im, kMu);
_mm_storeu_ps(&ef[0][i], ef_re);
_mm_storeu_ps(&ef[1][i], ef_im);
}
// scalar code for the remaining items.
for (; i < (PART_LEN1); i++) {
ef[0][i] /= (aec->xPow[i] + 1e-10f);
ef[1][i] /= (aec->xPow[i] + 1e-10f);
float absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
if (absEf > aec->errThresh) {
absEf = aec->errThresh / (absEf + 1e-10f);
ef[0][i] *= absEf;
ef[1][i] *= absEf;
}
// Stepsize factor
ef[0][i] *= aec->mu;
ef[1][i] *= aec->mu;
}
}
void WebRtcAec_InitAec_SSE2(void) {
WebRtcAec_FilterFar = FilterFarSSE2;
WebRtcAec_ScaleErrorSignal = ScaleErrorSignalSSE2;
}
#endif //__SSE2__