TLDR: (float)x * (float)x, with x a uint16, gets transformed into (float)widening_mul(x, x).
Repro:
#include <Halide.h>
using namespace Halide;
int main(int argc, char **argv) {
// Generate something to read from in uint16
Var x{"x"};
Func f("f_uint16");
f(x) = cast<uint16_t>(x);
f.compute_root();
// Read it
Func g{"g"};
Func casted{"casted"};
casted(x) = cast<float>(f(x));
g(x) = casted(x) * casted(x);
g.compute_root().vectorize(x, 8).align_bounds(x, 8);
Target t = get_target_from_environment().with_feature(Halide::Target::NoRuntime).with_feature(Target::NoAsserts).with_feature(Halide::Target::NoBoundsQuery);
printf("\n\n\n\n\n");
printf("Widing mul (problematic)\n");
g.compile_to_assembly("widening_mul.asm", {}, t);
printf("\n\n\n\n\n");
printf("Force the cast first!\n");
casted.compute_at(g, x).vectorize(x);
g.compile_to_assembly("cast_first.asm", {}, t);
return 0;
}
The problematic case produces:
produce g {
consume f_uint16 {
let t12 = g.min.0/8
for (g.s0.x.x, 0, g.extent.0.required.s) {
let t9.s = f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)]
g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8((uint32x8)widening_mul(t9.s, t9.s))
}
}
}
which corresponds to this assembly:
# %bb.3: # %"2_for_g.s0.x.x.preheader"
vpbroadcastd .LCPI0_0(%rip), %ymm0 # ymm0 = [1258291200,1258291200,1258291200,1258291200,1258291200,1258291200,1258291200,1258291200]
vpbroadcastd .LCPI0_1(%rip), %ymm1 # ymm1 = [1392508928,1392508928,1392508928,1392508928,1392508928,1392508928,1392508928,1392508928]
vbroadcastss .LCPI0_2(%rip), %ymm2 # ymm2 = [5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11]
movslq %ebp, %rdx
shlq $2, %r14
movl %r15d, %ecx
shlq $5, %rdx
subq %r14, %rdx
addq %rdx, %rbx
xorl %edx, %edx
.p2align 4
.LBB0_4: # %"2_for_g.s0.x.x"
# =>This Inner Loop Header: Depth=1
vpmovzxwd (%rax,%rdx), %ymm3 # ymm3 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
vpmulld %ymm3, %ymm3, %ymm3
vpblendw $170, %ymm0, %ymm3, %ymm4 # ymm4 = ymm3[0],ymm0[1],ymm3[2],ymm0[3],ymm3[4],ymm0[5],ymm3[6],ymm0[7],ymm3[8],ymm0[9],ymm3[10],ymm0[11],ymm3[12],ymm0[13],ymm3[14],ymm0[15]
vpsrld $16, %ymm3, %ymm3
vpblendw $170, %ymm1, %ymm3, %ymm3 # ymm3 = ymm3[0],ymm1[1],ymm3[2],ymm1[3],ymm3[4],ymm1[5],ymm3[6],ymm1[7],ymm3[8],ymm1[9],ymm3[10],ymm1[11],ymm3[12],ymm1[13],ymm3[14],ymm1[15]
vsubps %ymm2, %ymm3, %ymm3
vaddps %ymm3, %ymm4, %ymm3
vmovups %ymm3, (%rbx,%rdx,2)
addq $16, %rdx
decq %rcx
jne .LBB0_4
It's clearly doing some fancy bit-shifting magic to manually convert the resulting uint32 into a float (vpblendw, vpsrld, vsubps, vaddps), presumably because the MAX_UINT16*MAX_UINT16 product can occupy the full uint32 range and so cannot go straight through the signed vcvtdq2ps. There is no way this is faster than simply doing a single vcvtdq2ps on the zero-extended inputs... (uops.info reports a latency of 3 cycles for both vcvtdq2ps and vaddps).
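For reference, here is a minimal scalar sketch of what those magic constants appear to be doing; this is my reading of the .LCPI0_0/_1/_2 constants and the blend/shift/sub/add sequence above, purely as an illustration, not Halide's or LLVM's actual code:
// Scalar sketch of the uint32 -> float conversion implied by the magic constants.
// Assumption: this mirrors the vpblendw/vpsrld/vsubps/vaddps sequence above.
#include <cstdint>
#include <cstring>

static float bits_as_float(uint32_t bits) {
    float f;
    std::memcpy(&f, &bits, sizeof f);  // reinterpret the bit pattern as a float
    return f;
}

float uint32_to_float_magic(uint32_t v) {
    // Low 16 bits dropped into the mantissa of 2^23 (0x4B000000 == 1258291200):
    // the result is exactly 2^23 + (v & 0xFFFF).            -> vpblendw with ymm0
    float lo = bits_as_float(0x4B000000u | (v & 0xFFFFu));
    // High 16 bits dropped into the mantissa of 2^39 (0x53000000 == 1392508928):
    // the result is exactly 2^39 + (v >> 16) * 2^16.        -> vpsrld + vpblendw with ymm1
    float hi = bits_as_float(0x53000000u | (v >> 16));
    // 5.49764202e11 == 2^39 + 2^23, so subtracting it cancels both biases, and the
    // final add reassembles (v >> 16) * 2^16 + (v & 0xFFFF) == v, rounded to float.
    return (hi - 549764202496.0f) + lo;                      // -> vsubps + vaddps
}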
Compare this to forcing the order of operations by explicitly scheduling the casted Func:
for (g.s0.x.x, 0, g.extent.0.required.s) {
allocate casted[float32 * 8]
produce casted {
casted[ramp(0, 1, 8)] = float32x8(f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)])
}
consume casted {
let t23 = casted[ramp(0, 1, 8)]
g[ramp(((g.s0.x.x + t26)*8) - g.min.0, 1, 8) aligned(8, 0)] = t23*t23
}
free casted
}
with much neater assembly:
# %bb.3: # %"2_for_g.s0.x.x.preheader"
movslq %ebp, %rdx
shlq $2, %r14
movl %r15d, %ecx
shlq $5, %rdx
subq %r14, %rdx
addq %rdx, %rbx
xorl %edx, %edx
.p2align 4
.LBB0_4: # %"2_for_g.s0.x.x"
# =>This Inner Loop Header: Depth=1
vpmovzxwd (%rax,%rdx), %ymm0 # ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
vcvtdq2ps %ymm0, %ymm0
vmulps %ymm0, %ymm0, %ymm0
vmovups %ymm0, (%rbx,%rdx,2)
addq $16, %rdx
decq %rcx
jne .LBB0_4
Stepping through the lowering passes, FindIntrinsics is the pass responsible.
Lowering after removing dead allocations and hoisting loop invariants:
assert(reinterpret<uint64>((struct halide_buffer_t *)g.buffer) != (uint64)0, halide_error_buffer_argument_is_null("g"))
let g = (void *)_halide_buffer_get_host((struct halide_buffer_t *)g.buffer)
let g.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)g.buffer)
let g.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)g.buffer)
let g.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)g.buffer)
let g.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)g.buffer, 0)
let g.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0.required.s = (((g.extent.0 + g.min.0) + 7)/8) - (g.min.0/8)
assert(g.type == (uint32)73730, halide_error_bad_type("Output buffer g", g.type, (uint32)73730))
assert(g.dimensions == 1, halide_error_bad_dimensions("Output buffer g", g.dimensions, 1))
assert(((((g.min.0/8) + g.extent.0.required.s)*8) <= (g.extent.0 + g.min.0)) && ((g.min.0 % 8) == 0), halide_error_access_out_of_bounds("Output buffer g", 0, (g.min.0/8)*8, (((g.min.0/8) + g.extent.0.required.s)*8) + -1, g.min.0, (g.extent.0 + g.min.0) + -1))
assert(0 <= g.extent.0, halide_error_buffer_extents_negative("Output buffer g", 0, g.extent.0))
assert(g.stride.0 == 1, halide_error_constraint_violated("g.stride.0", g.stride.0, "1", 1))
assert((uint64)abs(int64(g.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("g", (uint64)abs(int64(g.extent.0)), (uint64)2147483647))
assert(!g.device_dirty, halide_error_device_dirty_with_no_device_support("Output buffer g"))
assert(g != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Output buffer g"))
allocate f_uint16[uint16 * (g.extent.0.required.s*8)]
produce f_uint16 {
let t11 = (g.min.0/8)*8
let t10 = g.extent.0.required.s*8
for (f_uint16.s0.x.rebased, 0, t10) {
f_uint16[f_uint16.s0.x.rebased] = uint16(f_uint16.s0.x.rebased + t11)
}
}
produce g {
consume f_uint16 {
let t12 = g.min.0/8
for (g.s0.x.x, 0, g.extent.0.required.s) {
let t9.s = f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)]
g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8(t9.s)*float32x8(t9.s)
}
}
}
free f_uint16
Finding intrinsics...
Lowering after finding intrinsics:
assert(reinterpret<uint64>((struct halide_buffer_t *)g.buffer) != (uint64)0, halide_error_buffer_argument_is_null("g"))
let g = (void *)_halide_buffer_get_host((struct halide_buffer_t *)g.buffer)
let g.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)g.buffer)
let g.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)g.buffer)
let g.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)g.buffer)
let g.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)g.buffer, 0)
let g.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0.required.s = (((g.extent.0 + g.min.0) + 7)/8) - (g.min.0/8)
assert(g.type == (uint32)73730, halide_error_bad_type("Output buffer g", g.type, (uint32)73730))
assert(g.dimensions == 1, halide_error_bad_dimensions("Output buffer g", g.dimensions, 1))
assert(((((g.min.0/8) + g.extent.0.required.s)*8) <= (g.extent.0 + g.min.0)) && ((g.min.0 % 8) == 0), halide_error_access_out_of_bounds("Output buffer g", 0, (g.min.0/8)*8, (((g.min.0/8) + g.extent.0.required.s)*8) + -1, g.min.0, (g.extent.0 + g.min.0) + -1))
assert(0 <= g.extent.0, halide_error_buffer_extents_negative("Output buffer g", 0, g.extent.0))
assert(g.stride.0 == 1, halide_error_constraint_violated("g.stride.0", g.stride.0, "1", 1))
assert((uint64)abs(int64(g.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("g", (uint64)abs(int64(g.extent.0)), (uint64)2147483647))
assert(!g.device_dirty, halide_error_device_dirty_with_no_device_support("Output buffer g"))
assert(g != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Output buffer g"))
allocate f_uint16[uint16 * (g.extent.0.required.s*8)]
produce f_uint16 {
let t11 = (g.min.0/8)*8
let t10 = g.extent.0.required.s*8
for (f_uint16.s0.x.rebased, 0, t10) {
f_uint16[f_uint16.s0.x.rebased] = uint16(f_uint16.s0.x.rebased + t11)
}
}
produce g {
consume f_uint16 {
let t12 = g.min.0/8
for (g.s0.x.x, 0, g.extent.0.required.s) {
let t9.s = f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)]
g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8((uint32x8)widening_mul(t9.s, t9.s))
}
}
}
free f_uint16
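The only difference between the two dumps is the store in the innermost loop, which goes from
g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8(t9.s)*float32x8(t9.s)
to
g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8((uint32x8)widening_mul(t9.s, t9.s))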