Vectorization of "FilterAdaptation":

* 1.0% AEC overall speedup for straight C path.
* 6.2% AEC overall speedup for SSE2 path.
* fix warnings, make code compile with "-std=gnu89
-Wstrict-prototypes -Wold-style-definition -Wmissing-prototypes
-Wmissing-declarations -Wdeclaration-after-statement -Wextra -Wall
-Werror"
Review URL: http://webrtc-codereview.appspot.com/24012

git-svn-id: http://webrtc.googlecode.com/svn/trunk@38 4adac7df-926f-26a2-2b94-8c16560cd09d
This commit is contained in:
cduvivier@google.com
2011-06-02 23:50:06 +00:00
parent 56307f1257
commit a4f6303c5d
3 changed files with 168 additions and 66 deletions

View File

@@ -20,9 +20,6 @@
#include "ring_buffer.h"
#include "system_wrappers/interface/cpu_features_wrapper.h"
#define IP_LEN PART_LEN // this must be at least ceil(2 + sqrt(PART_LEN))
#define W_LEN PART_LEN
// Noise suppression
static const int converged = 250;
@@ -218,15 +215,16 @@ int WebRtcAec_FreeAec(aec_t *aec)
static void FilterFar(aec_t *aec, float yf[2][PART_LEN1])
{
int i, j, pos;
int i;
for (i = 0; i < NR_PART; i++) {
int j;
int xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
int pos = i * PART_LEN1;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART*(PART_LEN1);
}
pos = i * PART_LEN1;
for (j = 0; j < PART_LEN1; j++) {
yf[0][j] += MulRe(aec->xfBuf[0][xPos + j], aec->xfBuf[1][xPos + j],
aec->wfBuf[0][ pos + j], aec->wfBuf[1][ pos + j]);
@@ -257,10 +255,68 @@ static void ScaleErrorSignal(aec_t *aec, float ef[2][PART_LEN1])
}
}
static void FilterAdaptation(aec_t *aec, float *fft, float ef[2][PART_LEN1],
int ip[IP_LEN], float wfft[W_LEN]) {
int i, j;
for (i = 0; i < NR_PART; i++) {
int xPos = (i + aec->xfBufBlockPos)*(PART_LEN1);
int pos;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART * PART_LEN1;
}
pos = i * PART_LEN1;
#ifdef UNCONSTR
for (j = 0; j < PART_LEN1; j++) {
aec->wfBuf[pos + j][0] += MulRe(aec->xfBuf[xPos + j][0],
-aec->xfBuf[xPos + j][1],
ef[j][0], ef[j][1]);
aec->wfBuf[pos + j][1] += MulIm(aec->xfBuf[xPos + j][0],
-aec->xfBuf[xPos + j][1],
ef[j][0], ef[j][1]);
}
#else
for (j = 0; j < PART_LEN; j++) {
fft[2 * j] = MulRe(aec->xfBuf[0][xPos + j],
-aec->xfBuf[1][xPos + j],
ef[0][j], ef[1][j]);
fft[2 * j + 1] = MulIm(aec->xfBuf[0][xPos + j],
-aec->xfBuf[1][xPos + j],
ef[0][j], ef[1][j]);
}
fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
-aec->xfBuf[1][xPos + PART_LEN],
ef[0][PART_LEN], ef[1][PART_LEN]);
rdft(PART_LEN2, -1, fft, ip, wfft);
memset(fft + PART_LEN, 0, sizeof(float) * PART_LEN);
// fft scaling
{
float scale = 2.0f / PART_LEN2;
for (j = 0; j < PART_LEN; j++) {
fft[j] *= scale;
}
}
rdft(PART_LEN2, 1, fft, ip, wfft);
aec->wfBuf[0][pos] += fft[0];
aec->wfBuf[0][pos + PART_LEN] += fft[1];
for (j = 1; j < PART_LEN; j++) {
aec->wfBuf[0][pos + j] += fft[2 * j];
aec->wfBuf[1][pos + j] += fft[2 * j + 1];
}
#endif // UNCONSTR
}
}
WebRtcAec_FilterFar_t WebRtcAec_FilterFar;
WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal;
extern void WebRtcAec_InitAec_SSE2(void);
WebRtcAec_FilterAdaptation_t WebRtcAec_FilterAdaptation;
int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
{
@@ -387,6 +443,7 @@ int WebRtcAec_InitAec(aec_t *aec, int sampFreq)
// Assembly optimization
WebRtcAec_FilterFar = FilterFar;
WebRtcAec_ScaleErrorSignal = ScaleErrorSignal;
WebRtcAec_FilterAdaptation = FilterAdaptation;
if (WebRtc_GetCPUInfo(kSSE2)) {
#if defined(__SSE2__)
WebRtcAec_InitAec_SSE2();
@@ -483,11 +540,10 @@ static void ProcessBlock(aec_t *aec, const short *farend,
const short *nearend, const short *nearendH,
short *output, short *outputH)
{
int i, j, pos;
int i;
float d[PART_LEN], y[PART_LEN], e[PART_LEN], dH[PART_LEN];
short eInt16[PART_LEN];
float scale;
int xPos;
float fft[PART_LEN2];
float xf[2][PART_LEN1], yf[2][PART_LEN1], ef[2][PART_LEN1];
@@ -656,56 +712,7 @@ static void ProcessBlock(aec_t *aec, const short *farend,
if (aec->adaptToggle) {
#endif
// Filter adaptation
for (i = 0; i < NR_PART; i++) {
xPos = (i + aec->xfBufBlockPos)*(PART_LEN1);
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART*(PART_LEN1);
}
pos = i * PART_LEN1;
#ifdef UNCONSTR
for (j = 0; j < PART_LEN1; j++) {
aec->wfBuf[pos + j][0] += MulRe(aec->xfBuf[xPos + j][0],
-aec->xfBuf[xPos + j][1], ef[j][0], ef[j][1]);
aec->wfBuf[pos + j][1] += MulIm(aec->xfBuf[xPos + j][0],
-aec->xfBuf[xPos + j][1], ef[j][0], ef[j][1]);
}
#else
fft[0] = MulRe(aec->xfBuf[0][xPos], -aec->xfBuf[1][xPos],
ef[0][0], ef[1][0]);
fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
-aec->xfBuf[1][xPos + PART_LEN],
ef[0][PART_LEN], ef[1][PART_LEN]);
for (j = 1; j < PART_LEN; j++) {
fft[2 * j] = MulRe(aec->xfBuf[0][xPos + j],
-aec->xfBuf[1][xPos + j],
ef[0][j], ef[1][j]);
fft[2 * j + 1] = MulIm(aec->xfBuf[0][xPos + j],
-aec->xfBuf[1][xPos + j],
ef[0][j], ef[1][j]);
}
rdft(PART_LEN2, -1, fft, ip, wfft);
memset(fft + PART_LEN, 0, sizeof(float)*PART_LEN);
scale = 2.0f / PART_LEN2;
for (j = 0; j < PART_LEN; j++) {
fft[j] *= scale; // fft scaling
}
rdft(PART_LEN2, 1, fft, ip, wfft);
aec->wfBuf[0][pos] += fft[0];
aec->wfBuf[0][pos + PART_LEN] += fft[1];
for (j = 1; j < PART_LEN; j++) {
aec->wfBuf[0][pos + j] += fft[2 * j];
aec->wfBuf[1][pos + j] += fft[2 * j + 1];
}
#endif // UNCONSTR
}
WebRtcAec_FilterAdaptation(aec, fft, ef, ip, wfft);
#ifdef G167
}
#endif

View File

@@ -170,10 +170,17 @@ typedef void (*WebRtcAec_FilterFar_t)(aec_t *aec, float yf[2][PART_LEN1]);
extern WebRtcAec_FilterFar_t WebRtcAec_FilterFar;
typedef void (*WebRtcAec_ScaleErrorSignal_t)(aec_t *aec, float ef[2][PART_LEN1]);
extern WebRtcAec_ScaleErrorSignal_t WebRtcAec_ScaleErrorSignal;
#define IP_LEN PART_LEN // this must be at least ceil(2 + sqrt(PART_LEN))
#define W_LEN PART_LEN
typedef void (*WebRtcAec_FilterAdaptation_t)
(aec_t *aec, float *fft, float ef[2][PART_LEN1], int ip[IP_LEN],
float wfft[W_LEN]);
extern WebRtcAec_FilterAdaptation_t WebRtcAec_FilterAdaptation;
int WebRtcAec_CreateAec(aec_t **aec);
int WebRtcAec_FreeAec(aec_t *aec);
int WebRtcAec_InitAec(aec_t *aec, int sampFreq);
void WebRtcAec_InitAec_SSE2(void);
void WebRtcAec_InitMetrics(aec_t *aec);
void WebRtcAec_ProcessFrame(aec_t *aec, const short *farend,

View File

@@ -30,15 +30,16 @@ __inline static float MulIm(float aRe, float aIm, float bRe, float bIm)
static void FilterFarSSE2(aec_t *aec, float yf[2][PART_LEN1])
{
int i, j, pos;
int i;
for (i = 0; i < NR_PART; i++) {
int j;
int xPos = (i + aec->xfBufBlockPos) * PART_LEN1;
int pos = i * PART_LEN1;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART*(PART_LEN1);
}
pos = i * PART_LEN1;
// vectorized code (four at once)
for (j = 0; j + 3 < PART_LEN1; j += 4) {
const __m128 xfBuf_re = _mm_loadu_ps(&aec->xfBuf[0][xPos + j]);
@@ -78,12 +79,12 @@ static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
// vectorized code (four at once)
for (i = 0; i + 3 < PART_LEN1; i += 4) {
const __m128 xPow = _mm_loadu_ps(&aec->xPow[i]);
__m128 ef_re = _mm_loadu_ps(&ef[0][i]);
__m128 ef_im = _mm_loadu_ps(&ef[1][i]);
const __m128 ef_re_base = _mm_loadu_ps(&ef[0][i]);
const __m128 ef_im_base = _mm_loadu_ps(&ef[1][i]);
const __m128 xPowPlus = _mm_add_ps(xPow, k1e_10f);
ef_re = _mm_div_ps(ef_re, xPowPlus);
ef_im = _mm_div_ps(ef_im, xPowPlus);
__m128 ef_re = _mm_div_ps(ef_re_base, xPowPlus);
__m128 ef_im = _mm_div_ps(ef_im_base, xPowPlus);
const __m128 ef_re2 = _mm_mul_ps(ef_re, ef_re);
const __m128 ef_im2 = _mm_mul_ps(ef_im, ef_im);
const __m128 ef_sum2 = _mm_add_ps(ef_re2, ef_im2);
@@ -107,9 +108,10 @@ static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
}
// scalar code for the remaining items.
for (; i < (PART_LEN1); i++) {
float absEf;
ef[0][i] /= (aec->xPow[i] + 1e-10f);
ef[1][i] /= (aec->xPow[i] + 1e-10f);
float absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
absEf = sqrtf(ef[0][i] * ef[0][i] + ef[1][i] * ef[1][i]);
if (absEf > aec->errThresh) {
absEf = aec->errThresh / (absEf + 1e-10f);
@@ -123,9 +125,95 @@ static void ScaleErrorSignalSSE2(aec_t *aec, float ef[2][PART_LEN1])
}
}
static void FilterAdaptationSSE2(aec_t *aec, float *fft, float ef[2][PART_LEN1],
int ip[IP_LEN], float wfft[W_LEN]) {
int i, j;
for (i = 0; i < NR_PART; i++) {
int xPos = (i + aec->xfBufBlockPos)*(PART_LEN1);
int pos = i * PART_LEN1;
// Check for wrap
if (i + aec->xfBufBlockPos >= NR_PART) {
xPos -= NR_PART * PART_LEN1;
}
#ifdef UNCONSTR
for (j = 0; j < PART_LEN1; j++) {
aec->wfBuf[pos + j][0] += MulRe(aec->xfBuf[xPos + j][0],
-aec->xfBuf[xPos + j][1],
ef[j][0], ef[j][1]);
aec->wfBuf[pos + j][1] += MulIm(aec->xfBuf[xPos + j][0],
-aec->xfBuf[xPos + j][1],
ef[j][0], ef[j][1]);
}
#else
// Process the whole array...
for (j = 0; j < PART_LEN; j+= 4) {
// Load xfBuf and ef.
const __m128 xfBuf_re = _mm_loadu_ps(&aec->xfBuf[0][xPos + j]);
const __m128 xfBuf_im = _mm_loadu_ps(&aec->xfBuf[1][xPos + j]);
const __m128 ef_re = _mm_loadu_ps(&ef[0][j]);
const __m128 ef_im = _mm_loadu_ps(&ef[1][j]);
// Calculate the product of conjugate(xfBuf) by ef.
// re(conjugate(a) * b) = aRe * bRe + aIm * bIm
// im(conjugate(a) * b)= aRe * bIm - aIm * bRe
const __m128 a = _mm_mul_ps(xfBuf_re, ef_re);
const __m128 b = _mm_mul_ps(xfBuf_im, ef_im);
const __m128 c = _mm_mul_ps(xfBuf_re, ef_im);
const __m128 d = _mm_mul_ps(xfBuf_im, ef_re);
const __m128 e = _mm_add_ps(a, b);
const __m128 f = _mm_sub_ps(c, d);
// Interleave real and imaginary parts.
const __m128 g = _mm_unpacklo_ps(e, f);
const __m128 h = _mm_unpackhi_ps(e, f);
// Store
_mm_storeu_ps(&fft[2*j + 0], g);
_mm_storeu_ps(&fft[2*j + 4], h);
}
// ... and fixup the first imaginary entry.
fft[1] = MulRe(aec->xfBuf[0][xPos + PART_LEN],
-aec->xfBuf[1][xPos + PART_LEN],
ef[0][PART_LEN], ef[1][PART_LEN]);
rdft(PART_LEN2, -1, fft, ip, wfft);
memset(fft + PART_LEN, 0, sizeof(float)*PART_LEN);
// fft scaling
{
float scale = 2.0f / PART_LEN2;
const __m128 scale_ps = _mm_load_ps1(&scale);
for (j = 0; j < PART_LEN; j+=4) {
const __m128 fft_ps = _mm_loadu_ps(&fft[j]);
const __m128 fft_scale = _mm_mul_ps(fft_ps, scale_ps);
_mm_storeu_ps(&fft[j], fft_scale);
}
}
rdft(PART_LEN2, 1, fft, ip, wfft);
{
float wt1 = aec->wfBuf[1][pos];
aec->wfBuf[0][pos + PART_LEN] += fft[1];
for (j = 0; j < PART_LEN; j+= 4) {
__m128 wtBuf_re = _mm_loadu_ps(&aec->wfBuf[0][pos + j]);
__m128 wtBuf_im = _mm_loadu_ps(&aec->wfBuf[1][pos + j]);
const __m128 fft0 = _mm_loadu_ps(&fft[2 * j + 0]);
const __m128 fft4 = _mm_loadu_ps(&fft[2 * j + 4]);
const __m128 fft_re = _mm_shuffle_ps(fft0, fft4, _MM_SHUFFLE(2, 0, 2 ,0));
const __m128 fft_im = _mm_shuffle_ps(fft0, fft4, _MM_SHUFFLE(3, 1, 3 ,1));
wtBuf_re = _mm_add_ps(wtBuf_re, fft_re);
wtBuf_im = _mm_add_ps(wtBuf_im, fft_im);
_mm_storeu_ps(&aec->wfBuf[0][pos + j], wtBuf_re);
_mm_storeu_ps(&aec->wfBuf[1][pos + j], wtBuf_im);
}
aec->wfBuf[1][pos] = wt1;
}
#endif // UNCONSTR
}
}
void WebRtcAec_InitAec_SSE2(void) {
WebRtcAec_FilterFar = FilterFarSSE2;
WebRtcAec_ScaleErrorSignal = ScaleErrorSignalSSE2;
WebRtcAec_FilterAdaptation = FilterAdaptationSSE2;
}
#endif //__SSE2__