From c515a6a6fe5aec0f9fafa3e3ea16c72271cad65a Mon Sep 17 00:00:00 2001 From: Folkert de Vries Date: Thu, 18 Dec 2025 14:11:08 +0100 Subject: [PATCH] add `vpdpbusd` avx512 intrinsic --- src/shims/x86/avx512.rs | 14 ++- src/shims/x86/mod.rs | 47 ++++++++++ tests/pass/shims/x86/intrinsics-x86-avx512.rs | 94 ++++++++++++++++++- 3 files changed, 153 insertions(+), 2 deletions(-) diff --git a/src/shims/x86/avx512.rs b/src/shims/x86/avx512.rs index a886f5622c..fb8554af81 100644 --- a/src/shims/x86/avx512.rs +++ b/src/shims/x86/avx512.rs @@ -3,7 +3,7 @@ use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::callconv::FnAbi; -use super::{permute, pmaddbw, psadbw, pshufb}; +use super::{permute, pmaddbw, psadbw, pshufb, vpdpbusd}; use crate::*; impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} @@ -109,6 +109,18 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { pshufb(this, left, right, dest)?; } + + // Used to implement the _mm512_dpbusd_epi32 function. + "vpdpbusd.512" | "vpdpbusd.256" | "vpdpbusd.128" => { + this.expect_target_feature_for_intrinsic(link_name, "avx512vnni")?; + if matches!(unprefixed_name, "vpdpbusd.128" | "vpdpbusd.256") { + this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?; + } + + let [src, a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?; + + vpdpbusd(this, src, a, b, dest)?; + } _ => return interp_ok(EmulateItemResult::NotSupported), } interp_ok(EmulateItemResult::NeedsReturn) diff --git a/src/shims/x86/mod.rs b/src/shims/x86/mod.rs index a5164cc87a..f3dba0b5ba 100644 --- a/src/shims/x86/mod.rs +++ b/src/shims/x86/mod.rs @@ -1200,6 +1200,53 @@ fn pshufb<'tcx>( interp_ok(()) } +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding signed +/// 8-bit integers in b, producing 4 intermediate signed 16-bit results. Sum these 4 results with +/// the corresponding 32-bit integer in src, and store the packed 32-bit results in dst. +/// +/// +/// +/// +fn vpdpbusd<'tcx>( + ecx: &mut crate::MiriInterpCx<'tcx>, + src: &OpTy<'tcx>, + a: &OpTy<'tcx>, + b: &OpTy<'tcx>, + dest: &MPlaceTy<'tcx>, +) -> InterpResult<'tcx, ()> { + let (src, src_len) = ecx.project_to_simd(src)?; + let (a, a_len) = ecx.project_to_simd(a)?; + let (b, b_len) = ecx.project_to_simd(b)?; + let (dest, dest_len) = ecx.project_to_simd(dest)?; + + // fn vpdpbusd(src: i32x16, a: i32x16, b: i32x16) -> i32x16; + // fn vpdpbusd256(src: i32x8, a: i32x8, b: i32x8) -> i32x8; + // fn vpdpbusd128(src: i32x4, a: i32x4, b: i32x4) -> i32x4; + assert_eq!(dest_len, src_len); + assert_eq!(dest_len, a_len); + assert_eq!(dest_len, b_len); + + for i in 0..dest_len { + let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?; + let a = ecx.read_scalar(&ecx.project_index(&a, i)?)?.to_u32()?; + let b = ecx.read_scalar(&ecx.project_index(&b, i)?)?.to_u32()?; + let dest = ecx.project_index(&dest, i)?; + + let [a1, a2, a3, a4] = a.to_le_bytes(); + let [b1, b2, b3, b4] = b.to_le_bytes(); + + let intermediate = i32::from(i16::from(a1).wrapping_mul(i16::from(b1.cast_signed()))) + .wrapping_add(i32::from(i16::from(a2).wrapping_mul(i16::from(b2.cast_signed())))) + .wrapping_add(i32::from(i16::from(a3).wrapping_mul(i16::from(b3.cast_signed())))) + .wrapping_add(i32::from(i16::from(a4).wrapping_mul(i16::from(b4.cast_signed())))); + + let res = Scalar::from_i32(intermediate.wrapping_add(src)); + ecx.write_scalar(res, &dest)?; + } + + interp_ok(()) +} + /// Packs two N-bit integer vectors to a single N/2-bit integers. /// /// The conversion from N-bit to N/2-bit should be provided by `f`. diff --git a/tests/pass/shims/x86/intrinsics-x86-avx512.rs b/tests/pass/shims/x86/intrinsics-x86-avx512.rs index f95429d59e..46df6596ae 100644 --- a/tests/pass/shims/x86/intrinsics-x86-avx512.rs +++ b/tests/pass/shims/x86/intrinsics-x86-avx512.rs @@ -1,6 +1,6 @@ // We're testing x86 target specific features //@only-target: x86_64 i686 -//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bitalg,+avx512vpopcntdq +//@compile-flags: -C target-feature=+avx512f,+avx512vl,+avx512bitalg,+avx512vpopcntdq,+avx512vnni #[cfg(target_arch = "x86")] use std::arch::x86::*; @@ -13,12 +13,14 @@ fn main() { assert!(is_x86_feature_detected!("avx512vl")); assert!(is_x86_feature_detected!("avx512bitalg")); assert!(is_x86_feature_detected!("avx512vpopcntdq")); + assert!(is_x86_feature_detected!("avx512vnni")); unsafe { test_avx512(); test_avx512bitalg(); test_avx512vpopcntdq(); test_avx512ternarylogic(); + test_avx512vnni(); } } @@ -411,6 +413,96 @@ unsafe fn test_avx512ternarylogic() { test_mm_ternarylogic_epi32(); } +#[target_feature(enable = "avx512vnni")] +unsafe fn test_avx512vnni() { + #[target_feature(enable = "avx512vnni")] + unsafe fn test_mm512_dpbusd_epi32() { + const SRC: [i32; 16] = [ + 1, + 0, + 0, + 7, + i32::MAX - 10, + i32::MIN + 10, + 12345, + -9876, + 0x01020304, + -1, + 42, + 0, + 1_000_000_000, + -1_000_000_000, + 17, + -17, + ]; + + const A: [i32; 16] = [ + 0x01010101, + 0xFFFF_FFFFu32 as i32, + 0xFFFF_FFFFu32 as i32, + 0x02_80_01_FF, + 0xFFFF_FFFFu32 as i32, + 0xFFFF_FFFFu32 as i32, + 0x00_FF_00_FF, + 0x7F_80_FF_01, + 0x10_20_30_40, + 0xDE_AD_BE_EFu32 as i32, + 0x00_00_00_FF, + 0x12_34_56_78, + 0xFF_00_FF_00u32 as i32, + 0x01_02_03_04, + 0xAA_55_AA_55u32 as i32, + 0x11_22_33_44, + ]; + + const B: [i32; 16] = [ + 0x01010101, + 0x7F7F_7F7F, + 0x8080_8080u32 as i32, + 0xFF_01_80_7Fu32 as i32, + 0x7F7F_7F7F, + 0x8080_8080u32 as i32, + 0x01_FF_01_FF, + 0x80_7F_00_FFu32 as i32, + 0x7F_01_FF_80u32 as i32, + 0x01_02_03_04, + 0xFF_FF_FF_FFu32 as i32, + 0x80_00_7F_FFu32 as i32, + 0x7F_80_7F_80u32 as i32, + 0x40_C0_20_E0u32 as i32, + 0x00_01_02_03, + 0x7F_7E_80_81u32 as i32, + ]; + + const DST: [i32; 16] = [ + 5, + 129540, + -130560, + 32390, + -2147354119, + 2147353098, + 11835, + -9877, + 16902884, + 2093, + -213, + 8498, + 1000064770, + -1000000096, + 697, + -8738, + ]; + + let src = _mm512_loadu_si512(SRC.as_ptr().cast::<__m512i>()); + let a = _mm512_loadu_si512(A.as_ptr().cast::<__m512i>()); + let b = _mm512_loadu_si512(B.as_ptr().cast::<__m512i>()); + let dst = _mm512_loadu_si512(DST.as_ptr().cast::<__m512i>()); + + assert_eq_m512i(_mm512_dpbusd_epi32(src, a, b), dst); + } + test_mm512_dpbusd_epi32(); +} + #[track_caller] unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) { assert_eq!(transmute::<_, [i32; 16]>(a), transmute::<_, [i32; 16]>(b))