diff --git a/test/vp9_error_block_test.cc b/test/vp9_error_block_test.cc index 8c5d5a2e2..d779706fc 100644 --- a/test/vp9_error_block_test.cc +++ b/test/vp9_error_block_test.cc @@ -136,7 +136,23 @@ TEST_P(ErrorBlockTest, ExtremeValues) { using std::tr1::make_tuple; -#if HAVE_SSE2 +#if CONFIG_USE_X86INC && HAVE_SSE2 +/* Test adapter: accepts a trailing bit-depth argument, asserts that it is 8, and forwards the other four arguments to vp9_highbd_block_error_8bit_sse2. */ int64_t wrap_vp9_highbd_block_error_8bit_sse2(const tran_low_t *coeff, + const tran_low_t *dqcoeff, + intptr_t block_size, + int64_t *ssz, int bps) { + assert(bps == 8); /* the wrapped function has no bit-depth parameter */ + return vp9_highbd_block_error_8bit_sse2(coeff, dqcoeff, block_size, ssz); +} + +/* Test adapter: accepts a trailing bit-depth argument, asserts that it is 8, and forwards the other four arguments to vp9_highbd_block_error_8bit_c. */ int64_t wrap_vp9_highbd_block_error_8bit_c(const tran_low_t *coeff, + const tran_low_t *dqcoeff, + intptr_t block_size, + int64_t *ssz, int bps) { + assert(bps == 8); /* the wrapped function has no bit-depth parameter */ + return vp9_highbd_block_error_8bit_c(coeff, dqcoeff, block_size, ssz); +} + INSTANTIATE_TEST_CASE_P( SSE2, ErrorBlockTest, ::testing::Values( @@ -145,7 +161,9 @@ INSTANTIATE_TEST_CASE_P( make_tuple(&vp9_highbd_block_error_sse2, &vp9_highbd_block_error_c, VPX_BITS_12), make_tuple(&vp9_highbd_block_error_sse2, - &vp9_highbd_block_error_c, VPX_BITS_8))); + &vp9_highbd_block_error_c, VPX_BITS_8), + make_tuple(&wrap_vp9_highbd_block_error_8bit_sse2, + &wrap_vp9_highbd_block_error_8bit_c, VPX_BITS_8))); /* 8-bit-only functions, passed through the two adapters */ #endif // HAVE_SSE2 #endif // CONFIG_VP9_HIGHBITDEPTH } // namespace diff --git a/vp9/common/vp9_rtcd_defs.pl b/vp9/common/vp9_rtcd_defs.pl index e633691e7..ed5f4ca32 100644 --- a/vp9/common/vp9_rtcd_defs.pl +++ b/vp9/common/vp9_rtcd_defs.pl @@ -241,11 +241,15 @@ if (vpx_config("CONFIG_VP9_TEMPORAL_DENOISING") eq "yes") { } if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") { -# the transform coefficients are held in 32-bit -# values, so the assembler code for vp9_block_error can no longer be used. 
add_proto qw/int64_t vp9_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz"; specialize qw/vp9_block_error/; + add_proto qw/int64_t vp9_highbd_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz, int bd"; + specialize qw/vp9_highbd_block_error/, "$sse2_x86inc"; + + add_proto qw/int64_t vp9_highbd_block_error_8bit/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz"; + specialize qw/vp9_highbd_block_error_8bit/, "$sse2_x86inc"; + add_proto qw/void vp9_quantize_fp/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan"; specialize qw/vp9_quantize_fp/; @@ -320,9 +324,6 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") { # ENCODEMB INVOKE - add_proto qw/int64_t vp9_highbd_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz, int bd"; - specialize qw/vp9_highbd_block_error sse2/; - add_proto qw/void vp9_highbd_quantize_fp/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan"; specialize qw/vp9_highbd_quantize_fp/; diff --git a/vp9/encoder/vp9_rdopt.c b/vp9/encoder/vp9_rdopt.c index 1818906df..19442917a 100644 --- a/vp9/encoder/vp9_rdopt.c +++ b/vp9/encoder/vp9_rdopt.c @@ -269,6 +269,71 @@ static void model_rd_for_sb(VP9_COMP *cpi, BLOCK_SIZE bsize, *out_dist_sum = dist_sum << 4; } +#if CONFIG_VP9_HIGHBITDEPTH +/* Accumulates the squared coeff[i] - dqcoeff[i] differences and the squared coeff[i] values over block_size entries, right-shifts both sums by 2 * (bd - 8) with rounding, stores the squared-coefficient sum in *ssz and returns the error sum. */ int64_t vp9_highbd_block_error_c(const tran_low_t *coeff, + 
const tran_low_t *dqcoeff, + intptr_t block_size, + int64_t *ssz, int bd) { + int i; + int64_t error = 0, sqcoeff = 0; + int shift = 2 * (bd - 8); /* zero when bd is 8, two bits per extra bit of depth otherwise */ + int rounding = shift > 0 ? 1 << (shift - 1) : 0; /* half of (1 << shift), added before the right shift */ + + for (i = 0; i < block_size; i++) { + const int64_t diff = coeff[i] - dqcoeff[i]; + error += diff * diff; + sqcoeff += (int64_t)coeff[i] * (int64_t)coeff[i]; + } + assert(error >= 0 && sqcoeff >= 0); + error = (error + rounding) >> shift; + sqcoeff = (sqcoeff + rounding) >> shift; + + *ssz = sqcoeff; + return error; +} + +/* Variant without a bit-depth argument: clamps every coeff and dqcoeff value to the range [lo, hi] before accumulating, applies no final shift, stores the squared-coefficient sum in *ssz and returns the squared-difference sum. */ int64_t vp9_highbd_block_error_8bit_c(const tran_low_t *coeff, + const tran_low_t *dqcoeff, + intptr_t block_size, + int64_t *ssz) { + int i; + int32_t c, d; + int64_t error = 0, sqcoeff = 0; + int16_t diff; /* NOTE(review): after clamping, d - c can fall outside the int16_t range, so the assignment below narrows it; presumably this mirrors the 16-bit psubw in the SSE2 routine - confirm */ + + const int32_t hi = 0x00007fff; + const int32_t lo = 0xffff8000; /* NOTE(review): presumably meant as -32768; the constant itself does not fit in int32_t, so this depends on the conversion result - confirm */ + + for (i = 0; i < block_size; i++) { + c = coeff[i]; + d = dqcoeff[i]; + + // Saturate to 16 bits + c = (c > hi) ? hi : ((c < lo) ? lo : c); + d = (d > hi) ? hi : ((d < lo) ? lo : d); + + diff = d - c; + error += diff * diff; + sqcoeff += c * c; + } + assert(error >= 0 && sqcoeff >= 0); + + *ssz = sqcoeff; + return error; +} + +/* File-local selector: calls vp9_highbd_block_error_8bit when bd is 8 and vp9_highbd_block_error (passing bd through) for any other value; returns whatever the selected function returns. */ static int64_t vp9_highbd_block_error_dispatch(const tran_low_t *coeff, + const tran_low_t *dqcoeff, + intptr_t block_size, + int64_t *ssz, int bd) { + if (bd == 8) { + return vp9_highbd_block_error_8bit(coeff, dqcoeff, block_size, ssz); + } else { + return vp9_highbd_block_error(coeff, dqcoeff, block_size, ssz, bd); + } +} +#endif // CONFIG_VP9_HIGHBITDEPTH + int64_t vp9_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz) { int i; @@ -297,30 +362,6 @@ int64_t vp9_block_error_fp_c(const int16_t *coeff, const int16_t *dqcoeff, return error; } -#if CONFIG_VP9_HIGHBITDEPTH -int64_t vp9_highbd_block_error_c(const tran_low_t *coeff, - const tran_low_t *dqcoeff, - intptr_t block_size, - int64_t *ssz, int bd) { - int i; - int64_t error = 0, sqcoeff = 0; - int shift = 2 * (bd - 8); - int rounding = shift > 0 ? 
1 << (shift - 1) : 0; - - for (i = 0; i < block_size; i++) { - const int64_t diff = coeff[i] - dqcoeff[i]; - error += diff * diff; - sqcoeff += (int64_t)coeff[i] * (int64_t)coeff[i]; - } - assert(error >= 0 && sqcoeff >= 0); - error = (error + rounding) >> shift; - sqcoeff = (sqcoeff + rounding) >> shift; - - *ssz = sqcoeff; - return error; -} -#endif // CONFIG_VP9_HIGHBITDEPTH - /* The trailing '0' is a terminator which is used inside cost_coeffs() to * decide whether to include cost of a trailing EOB node or not (i.e. we * can skip this if the last coefficient in this transform block, e.g. the @@ -430,8 +471,9 @@ static void dist_block(MACROBLOCK *x, int plane, int block, TX_SIZE tx_size, tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block); #if CONFIG_VP9_HIGHBITDEPTH const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8; - *out_dist = vp9_highbd_block_error(coeff, dqcoeff, 16 << ss_txfrm_size, - &this_sse, bd) >> shift; + *out_dist = vp9_highbd_block_error_dispatch(coeff, dqcoeff, + 16 << ss_txfrm_size, + &this_sse, bd) >> shift; #else *out_dist = vp9_block_error(coeff, dqcoeff, 16 << ss_txfrm_size, &this_sse) >> shift; @@ -831,7 +873,7 @@ static int64_t rd_pick_intra4x4block(VP9_COMP *cpi, MACROBLOCK *x, ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4, so->scan, so->neighbors, cpi->sf.use_fast_coef_costing); - distortion += vp9_highbd_block_error( + distortion += vp9_highbd_block_error_dispatch( coeff, BLOCK_OFFSET(pd->dqcoeff, block), 16, &unused, xd->bd) >> 2; if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd) @@ -929,8 +971,13 @@ static int64_t rd_pick_intra4x4block(VP9_COMP *cpi, MACROBLOCK *x, ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4, so->scan, so->neighbors, cpi->sf.use_fast_coef_costing); +#if CONFIG_VP9_HIGHBITDEPTH + distortion += vp9_highbd_block_error_8bit( + coeff, BLOCK_OFFSET(pd->dqcoeff, block), 16, &unused) >> 2; +#else distortion += vp9_block_error(coeff, 
BLOCK_OFFSET(pd->dqcoeff, block), 16, &unused) >> 2; +#endif if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd) goto next; vp9_iht4x4_add(tx_type, BLOCK_OFFSET(pd->dqcoeff, block), @@ -1368,6 +1415,9 @@ static int64_t encode_inter_mb_segment(VP9_COMP *cpi, k = i; for (idy = 0; idy < height / 4; ++idy) { for (idx = 0; idx < width / 4; ++idx) { +#if CONFIG_VP9_HIGHBITDEPTH + const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8; +#endif int64_t ssz, rd, rd1, rd2; tran_low_t* coeff; @@ -1377,14 +1427,8 @@ static int64_t encode_inter_mb_segment(VP9_COMP *cpi, coeff, 8); vp9_regular_quantize_b_4x4(x, 0, k, so->scan, so->iscan); #if CONFIG_VP9_HIGHBITDEPTH - if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) { - thisdistortion += vp9_highbd_block_error(coeff, - BLOCK_OFFSET(pd->dqcoeff, k), - 16, &ssz, xd->bd); - } else { - thisdistortion += vp9_block_error(coeff, BLOCK_OFFSET(pd->dqcoeff, k), - 16, &ssz); - } + thisdistortion += vp9_highbd_block_error_dispatch( + coeff, BLOCK_OFFSET(pd->dqcoeff, k), 16, &ssz, bd); #else thisdistortion += vp9_block_error(coeff, BLOCK_OFFSET(pd->dqcoeff, k), 16, &ssz); diff --git a/vp9/encoder/x86/vp9_highbd_error_sse2.asm b/vp9/encoder/x86/vp9_highbd_error_sse2.asm new file mode 100644 index 000000000..f3b8f0194 --- /dev/null +++ b/vp9/encoder/x86/vp9_highbd_error_sse2.asm @@ -0,0 +1,98 @@ +; +; Copyright (c) 2010 The WebM project authors. All Rights Reserved. +; +; Use of this source code is governed by a BSD-style license +; that can be found in the LICENSE file in the root of the source +; tree. An additional intellectual property rights grant can be found +; in the file PATENTS. All contributing project authors may +; be found in the AUTHORS file in the root of the source tree. 
+; + +%define private_prefix vp9 + +%include "third_party/x86inc/x86inc.asm" + +SECTION .text +ALIGN 16 + +; +; int64_t vp9_highbd_block_error_8bit(int32_t *coeff, int32_t *dqcoeff, +; intptr_t block_size, int64_t *ssz) +; + +INIT_XMM sse2 +cglobal highbd_block_error_8bit, 3, 3, 8, uqc, dqc, size, ssz + pxor m4, m4 ; sse accumulator + pxor m6, m6 ; ssz accumulator + pxor m5, m5 ; dedicated zero register + lea uqcq, [uqcq+sizeq*4] + lea dqcq, [dqcq+sizeq*4] + neg sizeq + + ALIGN 16 + +.loop: + mova m0, [dqcq+sizeq*4] + packssdw m0, [dqcq+sizeq*4+mmsize] + mova m2, [uqcq+sizeq*4] + packssdw m2, [uqcq+sizeq*4+mmsize] + + mova m1, [dqcq+sizeq*4+mmsize*2] + packssdw m1, [dqcq+sizeq*4+mmsize*3] + mova m3, [uqcq+sizeq*4+mmsize*2] + packssdw m3, [uqcq+sizeq*4+mmsize*3] + + add sizeq, mmsize + + ; individual errors are max. 15bit+sign, so squares are 30bit, and + ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit) + + psubw m0, m2 + pmaddwd m2, m2 + pmaddwd m0, m0 + + psubw m1, m3 + pmaddwd m3, m3 + pmaddwd m1, m1 + + ; accumulate in 64bit + punpckldq m7, m0, m5 + punpckhdq m0, m5 + paddq m4, m7 + + punpckldq m7, m2, m5 + punpckhdq m2, m5 + paddq m6, m7 + + punpckldq m7, m1, m5 + punpckhdq m1, m5 + paddq m4, m7 + + punpckldq m7, m3, m5 + punpckhdq m3, m5 + paddq m6, m7 + + paddq m4, m0 + paddq m4, m1 + paddq m6, m2 + paddq m6, m3 + + jnz .loop + + ; accumulate horizontally and store in return value + movhlps m5, m4 + movhlps m7, m6 + paddq m4, m5 + paddq m6, m7 + +%if ARCH_X86_64 + movq rax, m4 + movq [sszq], m6 +%else + mov eax, sszm + pshufd m5, m4, 0x1 + movq [eax], m6 + movd eax, m4 + movd edx, m5 +%endif + RET diff --git a/vp9/vp9cx.mk b/vp9/vp9cx.mk index 84b12d78e..a2cbacf48 100644 --- a/vp9/vp9cx.mk +++ b/vp9/vp9cx.mk @@ -100,8 +100,12 @@ endif ifeq ($(CONFIG_USE_X86INC),yes) VP9_CX_SRCS-$(HAVE_MMX) += encoder/x86/vp9_dct_mmx.asm +ifeq ($(CONFIG_VP9_HIGHBITDEPTH),yes) +VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_highbd_error_sse2.asm +else 
VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_error_sse2.asm endif +endif ifeq ($(ARCH_X86_64),yes) ifeq ($(CONFIG_USE_X86INC),yes)