From b383a17fa4c36a4242816ba6a1c57dca46d042d6 Mon Sep 17 00:00:00 2001 From: Kyle Siefring Date: Tue, 31 Oct 2017 11:19:19 -0400 Subject: [PATCH] Support building AVX-512 and implement sadx4 for AVX-512 The added AVX-512 support requires the subset of AVX-512 added in Skylake-X. Change-Id: I39666b00d10bf96d06c709823663eb09b89265b7 --- build/make/Makefile | 2 + build/make/configure.sh | 32 +++++++++++++- build/make/rtcd.pl | 4 +- configure | 1 + test/sad_test.cc | 8 ++++ test/test_libvpx.cc | 3 ++ vp9/common/vp9_rtcd_defs.pl | 1 + vpx_dsp/vpx_dsp.mk | 1 + vpx_dsp/vpx_dsp_rtcd_defs.pl | 3 +- vpx_dsp/x86/sad4d_avx512.c | 83 ++++++++++++++++++++++++++++++++++++ vpx_ports/x86.h | 25 +++++++---- 11 files changed, 149 insertions(+), 14 deletions(-) create mode 100644 vpx_dsp/x86/sad4d_avx512.c diff --git a/build/make/Makefile b/build/make/Makefile index 90522e5f6..f6b3f0630 100644 --- a/build/make/Makefile +++ b/build/make/Makefile @@ -139,6 +139,8 @@ $(BUILD_PFX)%_avx.c.d: CFLAGS += -mavx $(BUILD_PFX)%_avx.c.o: CFLAGS += -mavx $(BUILD_PFX)%_avx2.c.d: CFLAGS += -mavx2 $(BUILD_PFX)%_avx2.c.o: CFLAGS += -mavx2 +$(BUILD_PFX)%_avx512.c.d: CFLAGS += -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl +$(BUILD_PFX)%_avx512.c.o: CFLAGS += -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl # POWER $(BUILD_PFX)%_vsx.c.d: CFLAGS += -maltivec -mvsx diff --git a/build/make/configure.sh b/build/make/configure.sh index 36063a535..5f0f5247b 100644 --- a/build/make/configure.sh +++ b/build/make/configure.sh @@ -403,6 +403,23 @@ check_gcc_machine_option() { fi } +# tests for -m$2, -m$3, -m$4... toggling the feature given in $1. +check_gcc_machine_options() { + feature="$1" + shift + flags="-m$1" + shift + for opt in $*; do + flags="$flags -m$opt" + done + + if enabled gcc && ! disabled "$feature" && ! check_cflags $flags; then + RTCD_OPTIONS="${RTCD_OPTIONS}--disable-$feature " + else + soft_enable "$feature" + fi +} + write_common_config_banner() { print_webm_license config.mk "##" "" echo '# This file automatically generated by configure. Do not edit!' >> config.mk @@ -1237,6 +1254,13 @@ EOF AS=msvs msvs_arch_dir=x86-msvs vc_version=${tgt_cc##vs} + case $vc_version in + 7|8|9|10|11|12|13|14) + echo "${tgt_cc} does not support avx512, disabling....." + RTCD_OPTIONS="${RTCD_OPTIONS}--disable-avx512 " + soft_disable avx512 + ;; + esac case $vc_version in 7|8|9|10) echo "${tgt_cc} does not support avx/avx2, disabling....." @@ -1281,8 +1305,12 @@ EOF elif disabled $ext; then disable_exts="yes" else - # use the shortened version for the flag: sse4_1 -> sse4 - check_gcc_machine_option ${ext%_*} $ext + if [ "$ext" = "avx512" ]; then + check_gcc_machine_options $ext avx512f avx512cd avx512bw avx512dq avx512vl + else + # use the shortened version for the flag: sse4_1 -> sse4 + check_gcc_machine_option ${ext%_*} $ext + fi fi done diff --git a/build/make/rtcd.pl b/build/make/rtcd.pl index 9f44cb414..8fd624627 100755 --- a/build/make/rtcd.pl +++ b/build/make/rtcd.pl @@ -391,10 +391,10 @@ EOF &require("c"); if ($opts{arch} eq 'x86') { - @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2/); + @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2 avx512/); x86; } elsif ($opts{arch} eq 'x86_64') { - @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2/); + @ALL_ARCHS = filter(qw/mmx sse sse2 sse3 ssse3 sse4_1 avx avx2 avx512/); @REQUIRES = filter(keys %required ? keys %required : qw/mmx sse sse2/); &require(@REQUIRES); x86; diff --git a/configure b/configure index ba018e952..ae0a95832 100755 --- a/configure +++ b/configure @@ -244,6 +244,7 @@ ARCH_EXT_LIST_X86=" sse4_1 avx avx2 + avx512 " ARCH_EXT_LIST_LOONGSON=" diff --git a/test/sad_test.cc b/test/sad_test.cc index c30dca8a1..67c3c5315 100644 --- a/test/sad_test.cc +++ b/test/sad_test.cc @@ -896,6 +896,14 @@ const SadMxNx4Param x4d_avx2_tests[] = { INSTANTIATE_TEST_CASE_P(AVX2, SADx4Test, ::testing::ValuesIn(x4d_avx2_tests)); #endif // HAVE_AVX2 +#if HAVE_AVX512 +const SadMxNx4Param x4d_avx512_tests[] = { + SadMxNx4Param(64, 64, &vpx_sad64x64x4d_avx512), +}; +INSTANTIATE_TEST_CASE_P(AVX512, SADx4Test, + ::testing::ValuesIn(x4d_avx512_tests)); +#endif // HAVE_AVX512 + //------------------------------------------------------------------------------ // MIPS functions #if HAVE_MSA diff --git a/test/test_libvpx.cc b/test/test_libvpx.cc index 8a70b4e28..30641ae8c 100644 --- a/test/test_libvpx.cc +++ b/test/test_libvpx.cc @@ -53,6 +53,9 @@ int main(int argc, char **argv) { } if (!(simd_caps & HAS_AVX)) append_negative_gtest_filter(":AVX.*:AVX/*"); if (!(simd_caps & HAS_AVX2)) append_negative_gtest_filter(":AVX2.*:AVX2/*"); + if (!(simd_caps & HAS_AVX512)) { + append_negative_gtest_filter(":AVX512.*:AVX512/*"); + } #endif // ARCH_X86 || ARCH_X86_64 #if !CONFIG_SHARED diff --git a/vp9/common/vp9_rtcd_defs.pl b/vp9/common/vp9_rtcd_defs.pl index 28ae15a8b..32147dd52 100644 --- a/vp9/common/vp9_rtcd_defs.pl +++ b/vp9/common/vp9_rtcd_defs.pl @@ -30,6 +30,7 @@ if ($opts{arch} eq "x86_64") { $ssse3_x86_64 = 'ssse3'; $avx_x86_64 = 'avx'; $avx2_x86_64 = 'avx2'; + $avx512_x86_64 = 'avx512'; } # diff --git a/vpx_dsp/vpx_dsp.mk b/vpx_dsp/vpx_dsp.mk index 808ee36de..d18dd3107 100644 --- a/vpx_dsp/vpx_dsp.mk +++ b/vpx_dsp/vpx_dsp.mk @@ -327,6 +327,7 @@ DSP_SRCS-$(HAVE_SSSE3) += x86/sad_ssse3.asm DSP_SRCS-$(HAVE_SSE4_1) += x86/sad_sse4.asm DSP_SRCS-$(HAVE_AVX2) += x86/sad4d_avx2.c DSP_SRCS-$(HAVE_AVX2) += x86/sad_avx2.c +DSP_SRCS-$(HAVE_AVX512) += x86/sad4d_avx512.c DSP_SRCS-$(HAVE_SSE) += x86/sad4d_sse2.asm DSP_SRCS-$(HAVE_SSE) += x86/sad_sse2.asm diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl index bb54503fe..8ae847c3d 100644 --- a/vpx_dsp/vpx_dsp_rtcd_defs.pl +++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl @@ -20,6 +20,7 @@ if ($opts{arch} eq "x86_64") { $ssse3_x86_64 = 'ssse3'; $avx_x86_64 = 'avx'; $avx2_x86_64 = 'avx2'; + $avx512_x86_64 = 'avx512'; } # @@ -872,7 +873,7 @@ specialize qw/vpx_sad4x4x8 sse4_1 msa mmi/; # Multi-block SAD, comparing a reference to N independent blocks # add_proto qw/void vpx_sad64x64x4d/, "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array"; -specialize qw/vpx_sad64x64x4d avx2 neon msa sse2 vsx mmi/; +specialize qw/vpx_sad64x64x4d avx512 avx2 neon msa sse2 vsx mmi/; add_proto qw/void vpx_sad64x32x4d/, "const uint8_t *src_ptr, int src_stride, const uint8_t * const ref_ptr[], int ref_stride, uint32_t *sad_array"; specialize qw/vpx_sad64x32x4d neon msa sse2 vsx mmi/; diff --git a/vpx_dsp/x86/sad4d_avx512.c b/vpx_dsp/x86/sad4d_avx512.c new file mode 100644 index 000000000..5f2ab6ea7 --- /dev/null +++ b/vpx_dsp/x86/sad4d_avx512.c @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2017 The WebM project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include // AVX512 +#include "./vpx_dsp_rtcd.h" +#include "vpx/vpx_integer.h" + +void vpx_sad64x64x4d_avx512(const uint8_t *src, int src_stride, + const uint8_t *const ref[4], int ref_stride, + uint32_t res[4]) { + __m512i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg; + __m512i sum_ref0, sum_ref1, sum_ref2, sum_ref3; + __m512i sum_mlow, sum_mhigh; + int i; + const uint8_t *ref0, *ref1, *ref2, *ref3; + + ref0 = ref[0]; + ref1 = ref[1]; + ref2 = ref[2]; + ref3 = ref[3]; + sum_ref0 = _mm512_set1_epi16(0); + sum_ref1 = _mm512_set1_epi16(0); + sum_ref2 = _mm512_set1_epi16(0); + sum_ref3 = _mm512_set1_epi16(0); + for (i = 0; i < 64; i++) { + // load src and all refs + src_reg = _mm512_loadu_si512((const __m512i *)src); + ref0_reg = _mm512_loadu_si512((const __m512i *)ref0); + ref1_reg = _mm512_loadu_si512((const __m512i *)ref1); + ref2_reg = _mm512_loadu_si512((const __m512i *)ref2); + ref3_reg = _mm512_loadu_si512((const __m512i *)ref3); + // sum of the absolute differences between every ref-i to src + ref0_reg = _mm512_sad_epu8(ref0_reg, src_reg); + ref1_reg = _mm512_sad_epu8(ref1_reg, src_reg); + ref2_reg = _mm512_sad_epu8(ref2_reg, src_reg); + ref3_reg = _mm512_sad_epu8(ref3_reg, src_reg); + // sum every ref-i + sum_ref0 = _mm512_add_epi32(sum_ref0, ref0_reg); + sum_ref1 = _mm512_add_epi32(sum_ref1, ref1_reg); + sum_ref2 = _mm512_add_epi32(sum_ref2, ref2_reg); + sum_ref3 = _mm512_add_epi32(sum_ref3, ref3_reg); + + src += src_stride; + ref0 += ref_stride; + ref1 += ref_stride; + ref2 += ref_stride; + ref3 += ref_stride; + } + { + __m256i sum256; + __m128i sum128; + // in sum_ref-i the result is saved in the first 4 bytes + // the other 4 bytes are zeroed. + // sum_ref1 and sum_ref3 are shifted left by 4 bytes + sum_ref1 = _mm512_bslli_epi128(sum_ref1, 4); + sum_ref3 = _mm512_bslli_epi128(sum_ref3, 4); + + // merge sum_ref0 and sum_ref1 also sum_ref2 and sum_ref3 + sum_ref0 = _mm512_or_si512(sum_ref0, sum_ref1); + sum_ref2 = _mm512_or_si512(sum_ref2, sum_ref3); + + // merge every 64 bit from each sum_ref-i + sum_mlow = _mm512_unpacklo_epi64(sum_ref0, sum_ref2); + sum_mhigh = _mm512_unpackhi_epi64(sum_ref0, sum_ref2); + + // add the low 64 bit to the high 64 bit + sum_mlow = _mm512_add_epi32(sum_mlow, sum_mhigh); + + // add the low 128 bit to the high 128 bit + sum256 = _mm256_add_epi32(_mm512_castsi512_si256(sum_mlow), + _mm512_extracti32x8_epi32(sum_mlow, 1)); + sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum256), + _mm256_extractf128_si256(sum256, 1)); + + _mm_storeu_si128((__m128i *)(res), sum128); + } +} diff --git a/vpx_ports/x86.h b/vpx_ports/x86.h index 5aabb9e3a..ced65ac05 100644 --- a/vpx_ports/x86.h +++ b/vpx_ports/x86.h @@ -151,16 +151,17 @@ static INLINE uint64_t xgetbv(void) { #endif #endif -#define HAS_MMX 0x01 -#define HAS_SSE 0x02 -#define HAS_SSE2 0x04 -#define HAS_SSE3 0x08 -#define HAS_SSSE3 0x10 -#define HAS_SSE4_1 0x20 -#define HAS_AVX 0x40 -#define HAS_AVX2 0x80 +#define HAS_MMX 0x001 +#define HAS_SSE 0x002 +#define HAS_SSE2 0x004 +#define HAS_SSE3 0x008 +#define HAS_SSSE3 0x010 +#define HAS_SSE4_1 0x020 +#define HAS_AVX 0x040 +#define HAS_AVX2 0x080 +#define HAS_AVX512 0x100 #ifndef BIT -#define BIT(n) (1 << n) +#define BIT(n) (1u << n) #endif static INLINE int x86_simd_caps(void) { @@ -209,6 +210,12 @@ static INLINE int x86_simd_caps(void) { cpuid(7, 0, reg_eax, reg_ebx, reg_ecx, reg_edx); if (reg_ebx & BIT(5)) flags |= HAS_AVX2; + + // bits 16 (AVX-512F) & 17 (AVX-512DQ) & 28 (AVX-512CD) & + // 30 (AVX-512BW) & 32 (AVX-512VL) + if ((reg_ebx & (BIT(16) | BIT(17) | BIT(28) | BIT(30) | BIT(31))) == + (BIT(16) | BIT(17) | BIT(28) | BIT(30) | BIT(31))) + flags |= HAS_AVX512; } } }