From 1f8e8e5bf1eff47c22bcdfac5fe045d1d33ef01b Mon Sep 17 00:00:00 2001 From: Yi Luo Date: Fri, 17 Feb 2017 10:59:46 -0800 Subject: [PATCH] Fix idct8x8 SSSE3 SingleExtremeCoeff unit tests - In SSSE3 optimization, 16-bit addition and subtraction would overflow when input coefficient is 16-bit signed extreme values. - Function-level speed becomes slower (unit ms): idct8x8_64: 284 -> 294 idct8x8_12: 145 -> 158. BUG=webm:1332 Change-Id: I1e4bf9d30a6d4112b8cac5823729565bf145e40b --- test/partial_idct_test.cc | 35 ++----------- vpx_dsp/x86/inv_txfm_ssse3.c | 95 ++++++++++++++++++++++++++++-------- 2 files changed, 78 insertions(+), 52 deletions(-) diff --git a/test/partial_idct_test.cc b/test/partial_idct_test.cc index 3c78cb3e6..764c544e3 100644 --- a/test/partial_idct_test.cc +++ b/test/partial_idct_test.cc @@ -55,35 +55,6 @@ typedef std::tr1::tuple || - a == &wrapper) { - return 23625 - 1; - } -#else - (void)a; -#endif - return std::numeric_limits::max(); -} - -int16_t MinSupportedCoeff(InvTxfmWithBdFunc a) { -#if HAVE_SSSE3 && ARCH_X86_64 && !CONFIG_EMULATE_HARDWARE && \ - !CONFIG_VP9_HIGHBITDEPTH - if (a == &wrapper || - a == &wrapper) { - return -23625 + 1; - } -#else - (void)a; -#endif - return std::numeric_limits::min(); -} - class PartialIDctTest : public ::testing::TestWithParam { public: virtual ~PartialIDctTest() {} @@ -261,8 +232,8 @@ TEST_P(PartialIDctTest, AddOutputBlock) { } TEST_P(PartialIDctTest, SingleExtremeCoeff) { - const int16_t max_coeff = MaxSupportedCoeff(partial_itxfm_); - const int16_t min_coeff = MinSupportedCoeff(partial_itxfm_); + const int16_t max_coeff = std::numeric_limits::max(); + const int16_t min_coeff = std::numeric_limits::min(); for (int i = 0; i < last_nonzero_; ++i) { memset(input_block_, 0, sizeof(*input_block_) * input_block_size_); // Run once for min and once for max. @@ -285,7 +256,7 @@ TEST_P(PartialIDctTest, SingleExtremeCoeff) { } } -TEST_P(PartialIDctTest, Speed) { +TEST_P(PartialIDctTest, DISABLED_Speed) { // Keep runtime stable with transform size. const int kCountSpeedTestBlock = 500000000 / input_block_size_; InitMem(); diff --git a/vpx_dsp/x86/inv_txfm_ssse3.c b/vpx_dsp/x86/inv_txfm_ssse3.c index 923d482de..cfa6a732a 100644 --- a/vpx_dsp/x86/inv_txfm_ssse3.c +++ b/vpx_dsp/x86/inv_txfm_ssse3.c @@ -23,7 +23,8 @@ void vpx_idct8x8_64_add_ssse3(const tran_low_t *input, uint8_t *dest, const __m128i stg1_1 = pair_set_epi16(cospi_4_64, cospi_28_64); const __m128i stg1_2 = pair_set_epi16(-cospi_20_64, cospi_12_64); const __m128i stg1_3 = pair_set_epi16(cospi_12_64, cospi_20_64); - const __m128i stg2_0 = pair_set_epi16(2 * cospi_16_64, 2 * cospi_16_64); + const __m128i stk2_0 = pair_set_epi16(cospi_16_64, cospi_16_64); + const __m128i stk2_1 = pair_set_epi16(cospi_16_64, -cospi_16_64); const __m128i stg2_2 = pair_set_epi16(cospi_24_64, -cospi_8_64); const __m128i stg2_3 = pair_set_epi16(cospi_8_64, cospi_24_64); @@ -99,10 +100,26 @@ void vpx_idct8x8_64_add_ssse3(const tran_low_t *input, uint8_t *dest, const __m128i hi_26 = _mm_unpackhi_epi16(in2, in6); { - tmp0 = _mm_add_epi16(in0, in4); - tmp1 = _mm_sub_epi16(in0, in4); - stp2_0 = _mm_mulhrs_epi16(tmp0, stg2_0); - stp2_1 = _mm_mulhrs_epi16(tmp1, stg2_0); + tmp0 = _mm_unpacklo_epi16(in0, in4); + tmp1 = _mm_unpackhi_epi16(in0, in4); + + tmp2 = _mm_madd_epi16(tmp0, stk2_0); + tmp3 = _mm_madd_epi16(tmp1, stk2_0); + tmp4 = _mm_madd_epi16(tmp0, stk2_1); + tmp5 = _mm_madd_epi16(tmp1, stk2_1); + + tmp2 = _mm_add_epi32(tmp2, rounding); + tmp3 = _mm_add_epi32(tmp3, rounding); + tmp4 = _mm_add_epi32(tmp4, rounding); + tmp5 = _mm_add_epi32(tmp5, rounding); + + tmp2 = _mm_srai_epi32(tmp2, DCT_CONST_BITS); + tmp3 = _mm_srai_epi32(tmp3, DCT_CONST_BITS); + tmp4 = _mm_srai_epi32(tmp4, DCT_CONST_BITS); + tmp5 = _mm_srai_epi32(tmp5, DCT_CONST_BITS); + + stp2_0 = _mm_packs_epi32(tmp2, tmp3); + stp2_1 = _mm_packs_epi32(tmp4, tmp5); tmp0 = _mm_madd_epi16(lo_26, stg2_2); tmp1 = _mm_madd_epi16(hi_26, stg2_2); @@ -136,10 +153,26 @@ void vpx_idct8x8_64_add_ssse3(const tran_low_t *input, uint8_t *dest, stp1_2 = _mm_sub_epi16(stp2_1, stp2_2); stp1_3 = _mm_sub_epi16(stp2_0, stp2_3); - tmp0 = _mm_sub_epi16(stp2_6, stp2_5); - tmp2 = _mm_add_epi16(stp2_6, stp2_5); - stp1_5 = _mm_mulhrs_epi16(tmp0, stg2_0); - stp1_6 = _mm_mulhrs_epi16(tmp2, stg2_0); + tmp0 = _mm_unpacklo_epi16(stp2_6, stp2_5); + tmp1 = _mm_unpackhi_epi16(stp2_6, stp2_5); + + tmp2 = _mm_madd_epi16(tmp0, stk2_1); + tmp3 = _mm_madd_epi16(tmp1, stk2_1); + tmp4 = _mm_madd_epi16(tmp0, stk2_0); + tmp5 = _mm_madd_epi16(tmp1, stk2_0); + + tmp2 = _mm_add_epi32(tmp2, rounding); + tmp3 = _mm_add_epi32(tmp3, rounding); + tmp4 = _mm_add_epi32(tmp4, rounding); + tmp5 = _mm_add_epi32(tmp5, rounding); + + tmp2 = _mm_srai_epi32(tmp2, DCT_CONST_BITS); + tmp3 = _mm_srai_epi32(tmp3, DCT_CONST_BITS); + tmp4 = _mm_srai_epi32(tmp4, DCT_CONST_BITS); + tmp5 = _mm_srai_epi32(tmp5, DCT_CONST_BITS); + + stp1_5 = _mm_packs_epi32(tmp2, tmp3); + stp1_6 = _mm_packs_epi32(tmp4, tmp5); } /* Stage4 */ @@ -186,14 +219,18 @@ void vpx_idct8x8_64_add_ssse3(const tran_low_t *input, uint8_t *dest, void vpx_idct8x8_12_add_ssse3(const tran_low_t *input, uint8_t *dest, int stride) { const __m128i zero = _mm_setzero_si128(); + const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING); const __m128i final_rounding = _mm_set1_epi16(1 << 4); const __m128i stg1_0 = pair_set_epi16(2 * cospi_28_64, 2 * cospi_28_64); const __m128i stg1_1 = pair_set_epi16(2 * cospi_4_64, 2 * cospi_4_64); const __m128i stg1_2 = pair_set_epi16(-2 * cospi_20_64, -2 * cospi_20_64); const __m128i stg1_3 = pair_set_epi16(2 * cospi_12_64, 2 * cospi_12_64); const __m128i stg2_0 = pair_set_epi16(2 * cospi_16_64, 2 * cospi_16_64); + const __m128i stk2_0 = pair_set_epi16(cospi_16_64, cospi_16_64); + const __m128i stk2_1 = pair_set_epi16(cospi_16_64, -cospi_16_64); const __m128i stg2_2 = pair_set_epi16(2 * cospi_24_64, 2 * cospi_24_64); const __m128i stg2_3 = pair_set_epi16(2 * cospi_8_64, 2 * cospi_8_64); + const __m128i stg3_0 = pair_set_epi16(-cospi_16_64, cospi_16_64); __m128i in0, in1, in2, in3, in4, in5, in6, in7; __m128i stp1_0, stp1_1, stp1_2, stp1_3, stp1_4, stp1_5, stp1_6, stp1_7; @@ -233,6 +270,17 @@ void vpx_idct8x8_12_add_ssse3(const tran_low_t *input, uint8_t *dest, stp2_5 = _mm_unpacklo_epi64(tmp1, zero); stp2_6 = _mm_unpackhi_epi64(tmp1, zero); + tmp0 = _mm_unpacklo_epi16(stp2_5, stp2_6); + tmp1 = _mm_madd_epi16(tmp0, stg3_0); + tmp2 = _mm_madd_epi16(tmp0, stk2_0); // stg3_1 = stk2_0 + + tmp1 = _mm_add_epi32(tmp1, rounding); + tmp2 = _mm_add_epi32(tmp2, rounding); + tmp1 = _mm_srai_epi32(tmp1, DCT_CONST_BITS); + tmp2 = _mm_srai_epi32(tmp2, DCT_CONST_BITS); + + stp1_5 = _mm_packs_epi32(tmp1, tmp2); + // Stage3 tmp2 = _mm_add_epi16(stp2_0, stp2_2); tmp3 = _mm_sub_epi16(stp2_0, stp2_2); @@ -240,13 +288,6 @@ void vpx_idct8x8_12_add_ssse3(const tran_low_t *input, uint8_t *dest, stp1_2 = _mm_unpackhi_epi64(tmp3, tmp2); stp1_3 = _mm_unpacklo_epi64(tmp3, tmp2); - tmp0 = _mm_sub_epi16(stp2_6, stp2_5); - tmp1 = _mm_add_epi16(stp2_6, stp2_5); - - tmp2 = _mm_mulhrs_epi16(tmp0, stg2_0); - tmp3 = _mm_mulhrs_epi16(tmp1, stg2_0); - stp1_5 = _mm_unpacklo_epi64(tmp2, tmp3); - // Stage4 tmp0 = _mm_add_epi16(stp1_3, stp2_4); tmp1 = _mm_add_epi16(stp1_2, stp1_5); @@ -279,10 +320,24 @@ void vpx_idct8x8_12_add_ssse3(const tran_low_t *input, uint8_t *dest, stp1_2 = _mm_sub_epi16(stp2_1, stp2_2); stp1_3 = _mm_sub_epi16(stp2_0, stp2_3); - tmp0 = _mm_add_epi16(stp2_6, stp2_5); - tmp1 = _mm_sub_epi16(stp2_6, stp2_5); - stp1_6 = _mm_mulhrs_epi16(tmp0, stg2_0); - stp1_5 = _mm_mulhrs_epi16(tmp1, stg2_0); + tmp0 = _mm_unpacklo_epi16(stp2_6, stp2_5); + tmp1 = _mm_unpackhi_epi16(stp2_6, stp2_5); + + tmp2 = _mm_madd_epi16(tmp0, stk2_0); + tmp3 = _mm_madd_epi16(tmp1, stk2_0); + tmp2 = _mm_add_epi32(tmp2, rounding); + tmp3 = _mm_add_epi32(tmp3, rounding); + tmp2 = _mm_srai_epi32(tmp2, DCT_CONST_BITS); + tmp3 = _mm_srai_epi32(tmp3, DCT_CONST_BITS); + stp1_6 = _mm_packs_epi32(tmp2, tmp3); + + tmp2 = _mm_madd_epi16(tmp0, stk2_1); + tmp3 = _mm_madd_epi16(tmp1, stk2_1); + tmp2 = _mm_add_epi32(tmp2, rounding); + tmp3 = _mm_add_epi32(tmp3, rounding); + tmp2 = _mm_srai_epi32(tmp2, DCT_CONST_BITS); + tmp3 = _mm_srai_epi32(tmp3, DCT_CONST_BITS); + stp1_5 = _mm_packs_epi32(tmp2, tmp3); /* Stage4 */ in0 = _mm_add_epi16(stp1_0, stp2_7);