diff --git a/test/sad_test.cc b/test/sad_test.cc index e37356a68..8a4609254 100644 --- a/test/sad_test.cc +++ b/test/sad_test.cc @@ -676,6 +676,7 @@ INSTANTIATE_TEST_CASE_P(NEON, SADavgTest, ::testing::ValuesIn(avg_neon_tests)); const SadMxNx4Param x4d_neon_tests[] = { SadMxNx4Param(64, 64, &vpx_sad64x64x4d_neon), + SadMxNx4Param(64, 32, &vpx_sad64x32x4d_neon), SadMxNx4Param(32, 64, &vpx_sad32x64x4d_neon), SadMxNx4Param(32, 32, &vpx_sad32x32x4d_neon), SadMxNx4Param(32, 16, &vpx_sad32x16x4d_neon), diff --git a/vpx_dsp/arm/sad4d_neon.c b/vpx_dsp/arm/sad4d_neon.c index afb320aca..b04de3aff 100644 --- a/vpx_dsp/arm/sad4d_neon.c +++ b/vpx_dsp/arm/sad4d_neon.c @@ -176,93 +176,67 @@ void vpx_sad32x64x4d_neon(const uint8_t *src, int src_stride, sad32x_4d(src, src_stride, ref, ref_stride, res, 64); } -static INLINE unsigned int horizontal_long_add_16x8(const uint16x8_t vec_lo, - const uint16x8_t vec_hi) { - const uint32x4_t vec_l_lo = - vaddl_u16(vget_low_u16(vec_lo), vget_high_u16(vec_lo)); - const uint32x4_t vec_l_hi = - vaddl_u16(vget_low_u16(vec_hi), vget_high_u16(vec_hi)); - const uint32x4_t a = vaddq_u32(vec_l_lo, vec_l_hi); - const uint64x2_t b = vpaddlq_u32(a); - const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)), - vreinterpret_u32_u64(vget_high_u64(b))); - return vget_lane_u32(c, 0); +static INLINE void sum64x(const uint8x16_t a_0, const uint8x16_t a_1, + const uint8x16_t b_0, const uint8x16_t b_1, + uint16x8_t *sum) { + *sum = vabal_u8(*sum, vget_low_u8(a_0), vget_low_u8(b_0)); + *sum = vabal_u8(*sum, vget_high_u8(a_0), vget_high_u8(b_0)); + *sum = vabal_u8(*sum, vget_low_u8(a_1), vget_low_u8(b_1)); + *sum = vabal_u8(*sum, vget_high_u8(a_1), vget_high_u8(b_1)); } -// Calculate the absolute difference of 64 bytes from vec_src_00, vec_src_16, -// vec_src_32, vec_src_48 and ref. Accumulate partial sums in vec_sum_ref_lo -// and vec_sum_ref_hi. -static void sad_neon_64(const uint8x16_t vec_src_00, - const uint8x16_t vec_src_16, - const uint8x16_t vec_src_32, - const uint8x16_t vec_src_48, const uint8_t *ref, - uint16x8_t *vec_sum_ref_lo, - uint16x8_t *vec_sum_ref_hi) { - const uint8x16_t vec_ref_00 = vld1q_u8(ref); - const uint8x16_t vec_ref_16 = vld1q_u8(ref + 16); - const uint8x16_t vec_ref_32 = vld1q_u8(ref + 32); - const uint8x16_t vec_ref_48 = vld1q_u8(ref + 48); +static INLINE void sad64x_4d(const uint8_t *a, int a_stride, + const uint8_t *const b[4], int b_stride, + uint32_t *result, const int height) { + int i; + uint16x8_t sum_0 = vdupq_n_u16(0); + uint16x8_t sum_1 = vdupq_n_u16(0); + uint16x8_t sum_2 = vdupq_n_u16(0); + uint16x8_t sum_3 = vdupq_n_u16(0); + uint16x8_t sum_4 = vdupq_n_u16(0); + uint16x8_t sum_5 = vdupq_n_u16(0); + uint16x8_t sum_6 = vdupq_n_u16(0); + uint16x8_t sum_7 = vdupq_n_u16(0); + const uint8_t *b_loop[4] = { b[0], b[1], b[2], b[3] }; - *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_00), - vget_low_u8(vec_ref_00)); - *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_00), - vget_high_u8(vec_ref_00)); - *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_16), - vget_low_u8(vec_ref_16)); - *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_16), - vget_high_u8(vec_ref_16)); - *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_32), - vget_low_u8(vec_ref_32)); - *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_32), - vget_high_u8(vec_ref_32)); - *vec_sum_ref_lo = vabal_u8(*vec_sum_ref_lo, vget_low_u8(vec_src_48), - vget_low_u8(vec_ref_48)); - *vec_sum_ref_hi = vabal_u8(*vec_sum_ref_hi, vget_high_u8(vec_src_48), - vget_high_u8(vec_ref_48)); + for (i = 0; i < height; ++i) { + const uint8x16_t a_0 = vld1q_u8(a); + const uint8x16_t a_1 = vld1q_u8(a + 16); + const uint8x16_t a_2 = vld1q_u8(a + 32); + const uint8x16_t a_3 = vld1q_u8(a + 48); + a += a_stride; + sum64x(a_0, a_1, vld1q_u8(b_loop[0]), vld1q_u8(b_loop[0] + 16), &sum_0); + sum64x(a_2, a_3, vld1q_u8(b_loop[0] + 32), vld1q_u8(b_loop[0] + 48), + &sum_1); + b_loop[0] += b_stride; + sum64x(a_0, a_1, vld1q_u8(b_loop[1]), vld1q_u8(b_loop[1] + 16), &sum_2); + sum64x(a_2, a_3, vld1q_u8(b_loop[1] + 32), vld1q_u8(b_loop[1] + 48), + &sum_3); + b_loop[1] += b_stride; + sum64x(a_0, a_1, vld1q_u8(b_loop[2]), vld1q_u8(b_loop[2] + 16), &sum_4); + sum64x(a_2, a_3, vld1q_u8(b_loop[2] + 32), vld1q_u8(b_loop[2] + 48), + &sum_5); + b_loop[2] += b_stride; + sum64x(a_0, a_1, vld1q_u8(b_loop[3]), vld1q_u8(b_loop[3] + 16), &sum_6); + sum64x(a_2, a_3, vld1q_u8(b_loop[3] + 32), vld1q_u8(b_loop[3] + 48), + &sum_7); + b_loop[3] += b_stride; + } + + result[0] = vget_lane_u32(horizontal_add_long_uint16x8(sum_0, sum_1), 0); + result[1] = vget_lane_u32(horizontal_add_long_uint16x8(sum_2, sum_3), 0); + result[2] = vget_lane_u32(horizontal_add_long_uint16x8(sum_4, sum_5), 0); + result[3] = vget_lane_u32(horizontal_add_long_uint16x8(sum_6, sum_7), 0); +} + +void vpx_sad64x32x4d_neon(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t *res) { + sad64x_4d(src, src_stride, ref, ref_stride, res, 32); } void vpx_sad64x64x4d_neon(const uint8_t *src, int src_stride, const uint8_t *const ref[4], int ref_stride, uint32_t *res) { - int i; - uint16x8_t vec_sum_ref0_lo = vdupq_n_u16(0); - uint16x8_t vec_sum_ref0_hi = vdupq_n_u16(0); - uint16x8_t vec_sum_ref1_lo = vdupq_n_u16(0); - uint16x8_t vec_sum_ref1_hi = vdupq_n_u16(0); - uint16x8_t vec_sum_ref2_lo = vdupq_n_u16(0); - uint16x8_t vec_sum_ref2_hi = vdupq_n_u16(0); - uint16x8_t vec_sum_ref3_lo = vdupq_n_u16(0); - uint16x8_t vec_sum_ref3_hi = vdupq_n_u16(0); - const uint8_t *ref0, *ref1, *ref2, *ref3; - ref0 = ref[0]; - ref1 = ref[1]; - ref2 = ref[2]; - ref3 = ref[3]; - - for (i = 0; i < 64; ++i) { - const uint8x16_t vec_src_00 = vld1q_u8(src); - const uint8x16_t vec_src_16 = vld1q_u8(src + 16); - const uint8x16_t vec_src_32 = vld1q_u8(src + 32); - const uint8x16_t vec_src_48 = vld1q_u8(src + 48); - - sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref0, - &vec_sum_ref0_lo, &vec_sum_ref0_hi); - sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref1, - &vec_sum_ref1_lo, &vec_sum_ref1_hi); - sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref2, - &vec_sum_ref2_lo, &vec_sum_ref2_hi); - sad_neon_64(vec_src_00, vec_src_16, vec_src_32, vec_src_48, ref3, - &vec_sum_ref3_lo, &vec_sum_ref3_hi); - - src += src_stride; - ref0 += ref_stride; - ref1 += ref_stride; - ref2 += ref_stride; - ref3 += ref_stride; - } - - res[0] = horizontal_long_add_16x8(vec_sum_ref0_lo, vec_sum_ref0_hi); - res[1] = horizontal_long_add_16x8(vec_sum_ref1_lo, vec_sum_ref1_hi); - res[2] = horizontal_long_add_16x8(vec_sum_ref2_lo, vec_sum_ref2_hi); - res[3] = horizontal_long_add_16x8(vec_sum_ref3_lo, vec_sum_ref3_hi); + sad64x_4d(src, src_stride, ref, ref_stride, res, 64); } diff --git a/vpx_dsp/arm/sum_neon.h b/vpx_dsp/arm/sum_neon.h index c09841223..d74fe0cde 100644 --- a/vpx_dsp/arm/sum_neon.h +++ b/vpx_dsp/arm/sum_neon.h @@ -30,6 +30,15 @@ static INLINE uint32x2_t horizontal_add_uint16x8(const uint16x8_t a) { vreinterpret_u32_u64(vget_high_u64(c))); } +static INLINE uint32x2_t horizontal_add_long_uint16x8(const uint16x8_t a, + const uint16x8_t b) { + const uint32x4_t c = vpaddlq_u16(a); + const uint32x4_t d = vpadalq_u16(c, b); + const uint64x2_t e = vpaddlq_u32(d); + return vadd_u32(vreinterpret_u32_u64(vget_low_u64(e)), + vreinterpret_u32_u64(vget_high_u64(e))); +} + static INLINE uint32x2_t horizontal_add_uint32x4(const uint32x4_t a) { const uint64x2_t b = vpaddlq_u32(a); return vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)), diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl index 44d6e4c71..358d16914 100644 --- a/vpx_dsp/vpx_dsp_rtcd_defs.pl +++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl @@ -860,7 +860,7 @@ add_proto qw/void vpx_sad64x64x4d/, "const uint8_t *src_ptr, int src_stride, con specialize qw/vpx_sad64x64x4d avx2 neon msa sse2 vsx/; add_proto qw/void vpx_sad64x32x4d/, "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array"; -specialize qw/vpx_sad64x32x4d msa sse2 vsx/; +specialize qw/vpx_sad64x32x4d neon msa sse2 vsx/; add_proto qw/void vpx_sad32x64x4d/, "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array"; specialize qw/vpx_sad32x64x4d neon msa sse2 vsx/;