FindIntrinsics lifts cast from uint16 to float out of multiplication. #8913

@mcourteaux

Description

TL;DR: (float)x * (float)x, with x a uint16, gets transformed into (float)widening_mul(x, x).
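
In expression terms the rewrite looks roughly like this (a hand-written sketch based on the lowering dumps below, not actual compiler output):

  // x is a uint16 expression
  cast<float>(x) * cast<float>(x)     // before FindIntrinsics: convert, then float32 multiply
  cast<float>(widening_mul(x, x))     // after FindIntrinsics:  uint32 product, then one convert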

Repro:

#include <Halide.h>

using namespace Halide;
int main(int argc, char **argv) {

  // Generate something to read from in uint16
  Var x{"x"};
  Func f("f_uint16");
  f(x) = cast<uint16_t>(x);
  f.compute_root();

  // Read it
  Func g{"g"};
  Func casted{"casted"};
  casted(x) = cast<float>(f(x));
  g(x) = casted(x) * casted(x);
  g.compute_root().vectorize(x, 8).align_bounds(x, 8);

  Target t = get_target_from_environment()
                 .with_feature(Target::NoRuntime)
                 .with_feature(Target::NoAsserts)
                 .with_feature(Target::NoBoundsQuery);

  printf("\n\n\n\n\n");
  printf("Widing mul (problematic)\n");
  g.compile_to_assembly("widening_mul.asm", {}, t);


  printf("\n\n\n\n\n");
  printf("Force the cast first!\n");
  casted.compute_at(g, x).vectorize(x);
  g.compile_to_assembly("cast_first.asm", {}, t);

  return 0;
}

The problematic case produces:

produce g {
 consume f_uint16 {
  let t12 = g.min.0/8
  for (g.s0.x.x, 0, g.extent.0.required.s) {
   let t9.s = f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)]
   g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8((uint32x8)widening_mul(t9.s, t9.s))
  }
 }
}

which corresponds to this assembly:

# %bb.3:                                # %"2_for_g.s0.x.x.preheader"
	vpbroadcastd	.LCPI0_0(%rip), %ymm0   # ymm0 = [1258291200,1258291200,1258291200,1258291200,1258291200,1258291200,1258291200,1258291200]
	vpbroadcastd	.LCPI0_1(%rip), %ymm1   # ymm1 = [1392508928,1392508928,1392508928,1392508928,1392508928,1392508928,1392508928,1392508928]
	vbroadcastss	.LCPI0_2(%rip), %ymm2   # ymm2 = [5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11,5.49764202E+11]
	movslq	%ebp, %rdx
	shlq	$2, %r14
	movl	%r15d, %ecx
	shlq	$5, %rdx
	subq	%r14, %rdx
	addq	%rdx, %rbx
	xorl	%edx, %edx
	.p2align	4

.LBB0_4:                                # %"2_for_g.s0.x.x"
                                        # =>This Inner Loop Header: Depth=1
	vpmovzxwd	(%rax,%rdx), %ymm3      # ymm3 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
	vpmulld	%ymm3, %ymm3, %ymm3
	vpblendw	$170, %ymm0, %ymm3, %ymm4       # ymm4 = ymm3[0],ymm0[1],ymm3[2],ymm0[3],ymm3[4],ymm0[5],ymm3[6],ymm0[7],ymm3[8],ymm0[9],ymm3[10],ymm0[11],ymm3[12],ymm0[13],ymm3[14],ymm0[15]
	vpsrld	$16, %ymm3, %ymm3
	vpblendw	$170, %ymm1, %ymm3, %ymm3       # ymm3 = ymm3[0],ymm1[1],ymm3[2],ymm1[3],ymm3[4],ymm1[5],ymm3[6],ymm1[7],ymm3[8],ymm1[9],ymm3[10],ymm1[11],ymm3[12],ymm1[13],ymm3[14],ymm1[15]
	vsubps	%ymm2, %ymm3, %ymm3
	vaddps	%ymm3, %ymm4, %ymm3
	vmovups	%ymm3, (%rbx,%rdx,2)
	addq	$16, %rdx
	decq	%rcx
	jne	.LBB0_4

It's clearly doing some fancy bit-manipulation magic to convert the resulting uint32 values to float by hand (vpblendw, vpsrld, vsubps, vaddps): because MAX_UINT16 * MAX_UINT16 exceeds INT32_MAX, the product has to be treated as a full-range unsigned 32-bit value, and AVX2 has no unsigned-int32-to-float instruction, so LLVM emulates the conversion. There is no way this is faster than converting before the multiply with a single vcvtdq2ps (uops.info reports a latency of 3 cycles for both vcvtdq2ps and vaddps).
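
For reference, here is a scalar sketch of what that instruction sequence computes. This is my own reconstruction of the standard unsigned-to-float emulation, not code taken from LLVM or Halide; the magic constants match the .LCPI0_* values in the listing above (the vpblendw instructions fuse the masking and the OR into one step).

#include <cstdint>
#include <cstring>

// Convert a uint32 to float without an unsigned conversion instruction:
// splice each 16-bit half into the mantissa of a float with a fixed
// exponent, then subtract the combined bias and add the halves.
float uint32_to_float_bit_trick(uint32_t v) {
  // Low 16 bits ORed into the mantissa of 2^23
  // (0x4B000000 = 1258291200, the .LCPI0_0 constant): float value 2^23 + lo.
  uint32_t lo_bits = (v & 0xFFFFu) | 0x4B000000u;
  // High 16 bits ORed into the mantissa of 2^39
  // (0x53000000 = 1392508928, the .LCPI0_1 constant): float value 2^39 + hi*2^16.
  uint32_t hi_bits = (v >> 16) | 0x53000000u;

  float lo, hi;
  std::memcpy(&lo, &lo_bits, sizeof(float));
  std::memcpy(&hi, &hi_bits, sizeof(float));

  // Subtract the combined bias 2^39 + 2^23 = 549764202496.0f (the
  // 5.49764202E+11 constant, .LCPI0_2), then add the low half:
  // (2^39 + hi*2^16) - (2^39 + 2^23) + (2^23 + lo) == v, rounded once.
  return (hi - 549764202496.0f) + lo;
}

Per vector of 8 results, that is five instructions (two blends, a shift, a subtract, and an add) doing the job of the single vcvtdq2ps in the cast-first version below.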

Compare this to the version where the order of operations is forced by explicitly scheduling the casted Func:

  for (g.s0.x.x, 0, g.extent.0.required.s) {
   allocate casted[float32 * 8]
   produce casted {
    casted[ramp(0, 1, 8)] = float32x8(f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)])
   }
   consume casted {
    let t23 = casted[ramp(0, 1, 8)]
    g[ramp(((g.s0.x.x + t26)*8) - g.min.0, 1, 8) aligned(8, 0)] = t23*t23
   }
   free casted
  }

with much neater assembly:

# %bb.3:                                # %"2_for_g.s0.x.x.preheader"
	movslq	%ebp, %rdx
	shlq	$2, %r14
	movl	%r15d, %ecx
	shlq	$5, %rdx
	subq	%r14, %rdx
	addq	%rdx, %rbx
	xorl	%edx, %edx
	.p2align	4

.LBB0_4:                                # %"2_for_g.s0.x.x"
                                        # =>This Inner Loop Header: Depth=1
	vpmovzxwd	(%rax,%rdx), %ymm0      # ymm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero
	vcvtdq2ps	%ymm0, %ymm0
	vmulps	%ymm0, %ymm0, %ymm0
	vmovups	%ymm0, (%rbx,%rdx,2)
	addq	$16, %rdx
	decq	%rcx
	jne	.LBB0_4

Stepping through the lowering passes shows that FindIntrinsics is responsible.

Lowering after removing dead allocations and hoisting loop invariants:

assert(reinterpret<uint64>((struct halide_buffer_t *)g.buffer) != (uint64)0, halide_error_buffer_argument_is_null("g"))
let g = (void *)_halide_buffer_get_host((struct halide_buffer_t *)g.buffer)
let g.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)g.buffer)
let g.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)g.buffer)
let g.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)g.buffer)
let g.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)g.buffer, 0)
let g.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0.required.s = (((g.extent.0 + g.min.0) + 7)/8) - (g.min.0/8)
assert(g.type == (uint32)73730, halide_error_bad_type("Output buffer g", g.type, (uint32)73730))
assert(g.dimensions == 1, halide_error_bad_dimensions("Output buffer g", g.dimensions, 1))
assert(((((g.min.0/8) + g.extent.0.required.s)*8) <= (g.extent.0 + g.min.0)) && ((g.min.0 % 8) == 0), halide_error_access_out_of_bounds("Output buffer g", 0, (g.min.0/8)*8, (((g.min.0/8) + g.extent.0.required.s)*8) + -1, g.min.0, (g.extent.0 + g.min.0) + -1))
assert(0 <= g.extent.0, halide_error_buffer_extents_negative("Output buffer g", 0, g.extent.0))
assert(g.stride.0 == 1, halide_error_constraint_violated("g.stride.0", g.stride.0, "1", 1))
assert((uint64)abs(int64(g.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("g", (uint64)abs(int64(g.extent.0)), (uint64)2147483647))
assert(!g.device_dirty, halide_error_device_dirty_with_no_device_support("Output buffer g"))
assert(g != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Output buffer g"))
allocate f_uint16[uint16 * (g.extent.0.required.s*8)]
produce f_uint16 {
 let t11 = (g.min.0/8)*8
 let t10 = g.extent.0.required.s*8
 for (f_uint16.s0.x.rebased, 0, t10) {
  f_uint16[f_uint16.s0.x.rebased] = uint16(f_uint16.s0.x.rebased + t11)
 }
}
produce g {
 consume f_uint16 {
  let t12 = g.min.0/8
  for (g.s0.x.x, 0, g.extent.0.required.s) {
   let t9.s = f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)]
   g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8(t9.s)*float32x8(t9.s)
  }
 }
}
free f_uint16

Finding intrinsics...
Lowering after finding intrinsics:

assert(reinterpret<uint64>((struct halide_buffer_t *)g.buffer) != (uint64)0, halide_error_buffer_argument_is_null("g"))
let g = (void *)_halide_buffer_get_host((struct halide_buffer_t *)g.buffer)
let g.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)g.buffer)
let g.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)g.buffer)
let g.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)g.buffer)
let g.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)g.buffer, 0)
let g.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)g.buffer, 0)
let g.extent.0.required.s = (((g.extent.0 + g.min.0) + 7)/8) - (g.min.0/8)
assert(g.type == (uint32)73730, halide_error_bad_type("Output buffer g", g.type, (uint32)73730))
assert(g.dimensions == 1, halide_error_bad_dimensions("Output buffer g", g.dimensions, 1))
assert(((((g.min.0/8) + g.extent.0.required.s)*8) <= (g.extent.0 + g.min.0)) && ((g.min.0 % 8) == 0), halide_error_access_out_of_bounds("Output buffer g", 0, (g.min.0/8)*8, (((g.min.0/8) + g.extent.0.required.s)*8) + -1, g.min.0, (g.extent.0 + g.min.0) + -1))
assert(0 <= g.extent.0, halide_error_buffer_extents_negative("Output buffer g", 0, g.extent.0))
assert(g.stride.0 == 1, halide_error_constraint_violated("g.stride.0", g.stride.0, "1", 1))
assert((uint64)abs(int64(g.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("g", (uint64)abs(int64(g.extent.0)), (uint64)2147483647))
assert(!g.device_dirty, halide_error_device_dirty_with_no_device_support("Output buffer g"))
assert(g != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Output buffer g"))
allocate f_uint16[uint16 * (g.extent.0.required.s*8)]
produce f_uint16 {
 let t11 = (g.min.0/8)*8
 let t10 = g.extent.0.required.s*8
 for (f_uint16.s0.x.rebased, 0, t10) {
  f_uint16[f_uint16.s0.x.rebased] = uint16(f_uint16.s0.x.rebased + t11)
 }
}
produce g {
 consume f_uint16 {
  let t12 = g.min.0/8
  for (g.s0.x.x, 0, g.extent.0.required.s) {
   let t9.s = f_uint16[ramp(g.s0.x.x*8, 1, 8) aligned(8, 0)]
   g[ramp(((g.s0.x.x + t12)*8) - g.min.0, 1, 8) aligned(8, 0)] = float32x8((uint32x8)widening_mul(t9.s, t9.s))
  }
 }
}
free f_uint16
