diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index 55de986b6721..d8f97698af49 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -564,6 +564,7 @@ def MIGraphX_DotOp : MIGraphX_Op<"dot">, let assemblyFormat = [{ $in_a `,` $in_b attr-dict `:` type($in_a) `,` type($in_b) `->` type($output) }]; + let hasVerifier = 1; } def MIGraphX_SoftmaxOp : diff --git a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp index 605881e9c691..30707536a049 100644 --- a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp +++ b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp @@ -384,6 +384,56 @@ LogicalResult UnpackOp::verify() { return success(); } +static LogicalResult isValidDotOp(Operation *op, MIXRShapedType inAType, + MIXRShapedType inBType, + MIXRShapedType outputType) { + ArrayRef shapeA = inAType.getShape(); + ArrayRef shapeB = inBType.getShape(); + ArrayRef shapeOut = outputType.getShape(); + int64_t outputRank = outputType.getRank(); + + if (!llvm::all_of( + ArrayRef{inAType.getRank(), inBType.getRank(), outputRank}, + [](int64_t rank) { return rank >= 2; })) { + return op->emitOpError("expect operand to have rank greater or equal to 2"); + } + + // Batch dimensions (all dims except the last two) must be compatible. + // Broadcasting is allowed when one operand's batch dims are all ones + // or when one operand has no batch dims (rank 2). For example: + // A = {3, 2, 2, 2} and B = {1, 1, 2, 2} (batch B is all ones) - valid + // A = {3, 2, 2, 2} and B = {2, 2} (B has no batch dims) - valid + // A = {3, 2, 2, 2} and B = {2, 3, 2, 2} (batch dims differ) - invalid + ArrayRef batchA = shapeA.drop_back(2); + ArrayRef batchB = shapeB.drop_back(2); + bool hasLeadingOnesB = llvm::all_of(batchB, [](int64_t d) { return d == 1; }); + if (!hasLeadingOnesB && + !std::equal(batchA.begin(), batchA.end(), batchB.begin(), batchB.end())) { + return op->emitOpError("batch dimension mismatch: the first operand (") + << inAType << ") and the second operand (" << inBType + << ") have incompatible batch dimensions"; + } + + int64_t lastAShape = shapeA[shapeA.size() - 1]; + int64_t secondLastBShape = shapeB[shapeB.size() - 2]; + if (lastAShape != secondLastBShape) { + return op->emitOpError( + "contraction dimension mismatch: the first operand (") + << inAType << ") and the second operand (" << inBType + << ") have incompatible contraction dimensions"; + } + + // checking the output dimension, which must match the input + if (!std::equal(shapeA.rbegin() + 2, shapeA.rend(), shapeOut.rbegin() + 2, + shapeOut.rend()) || + *std::prev(shapeOut.end()) != *std::prev(shapeB.end()) || + *std::prev(shapeOut.end(), 2) != *std::prev(shapeA.end(), 2)) { + return op->emitOpError("result type is inconsistent with input shapes"); + } + + return success(); +} + LogicalResult QuantDotOp::verify() { MIXRShapedType inAType = getInA().getType(); MIXRShapedType inBType = getInB().getType(); @@ -431,5 +481,12 @@ LogicalResult QuantDotOp::verify() { return emitOpError("Quant Dot ops requires scales to be provided to use " "f4E2M1FN element type"); } - return success(); + return isValidDotOp(getOperation(), inAType, inBType, getType()); +} + +LogicalResult DotOp::verify() { + MIXRShapedType inAType = getInA().getType(); + MIXRShapedType inBType = getInB().getType(); + + return isValidDotOp(getOperation(), inAType, inBType, getType()); } diff --git a/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir b/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir index 4e9ab03592da..1dec63931fd2 100644 --- a/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir +++ b/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir @@ -72,15 +72,6 @@ func.func @dot_f16(%arg0: !migraphx.shaped<8x64x64x320xf16, 1310720x20480x320x1> // ----- -func.func @dot_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { - // expected-error @+2 {{operands must have the same rank}} - // expected-error @+1 {{failed to legalize operation 'migraphx.dot' that was explicitly marked illegal}} - %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1> - func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> -} - -// ----- - func.func @dot_unranked_tensor(%arg0 : !migraphx.shaped, %arg1: !migraphx.shaped) -> !migraphx.shaped { // expected-error @+2 {{only static shape is supported for now}} // expected-error @+1 {{failed to legalize operation 'migraphx.dot' that was explicitly marked illegal}} diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index 3e5a94c0e3c7..32c92195fd9b 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -225,3 +225,48 @@ func.func @migraphx_quant_dot_f4_n_scales(%arg0: !migraphx.shaped<1x16x512xf4E2M -> <1x16x16xf32, 256x16x1> return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> } + +// ----- + +// CHECK-LABEL: func.func @dot_rank_less_than_2 +func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !migraphx.shaped<320x64xf16, 64x1>) -> !migraphx.shaped<64xf16, 1> { + // expected-error @+1 {{expect operand to have rank greater or equal to 2}} + %0 = migraphx.dot %arg0, %arg1 : <320xf16, 1>, <320x64xf16, 64x1> -> <64xf16, 1> + return %0 : !migraphx.shaped<64xf16, 1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_incompatible_inner_dim +func.func @dot_incompatible_inner_dim(%arg0: !migraphx.shaped<2x64x320xf16, 20480x320x1>, %arg1: !migraphx.shaped<2x256x64xf16, 16384x64x1>) -> !migraphx.shaped<2x64x64xf16, 4096x64x1> { + // expected-error @+1 {{contraction dimension mismatch: the first operand}} + %0 = migraphx.dot %arg0, %arg1 : <2x64x320xf16, 20480x320x1>, <2x256x64xf16, 16384x64x1> -> <2x64x64xf16, 4096x64x1> + return %0 : !migraphx.shaped<2x64x64xf16, 4096x64x1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_invalid_batch +func.func @dot_invalid_batch(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { + // expected-error@+1 {{batch dimension mismatch: the first operand ('!migraphx.shaped<3x2x2x2xf32, 8x4x2x1>') and the second operand ('!migraphx.shaped<6x2x2xf32, 4x2x1>') have incompatible batch dimensions}} + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1> + func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_invalid_broadcast +func.func @dot_invalid_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x3x2x2xf32, 12x4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { + // expected-error@+1 {{batch dimension mismatch: the first operand}} + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <2x3x2x2xf32, 12x4x2x1> -> <3x2x2x2xf32, 8x4x2x1> + func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_result_shape_mismatch +func.func @dot_result_shape_mismatch(%arg0: !migraphx.shaped<2x3x4xf16, 12x4x1>, %arg1: !migraphx.shaped<2x4x5xf16, 20x5x1>) -> !migraphx.shaped<2x3x4xf16, 12x4x1> { + // expected-error @+1 {{result type is inconsistent with input shapes}} + %0 = migraphx.dot %arg0, %arg1 : <2x3x4xf16, 12x4x1>, <2x4x5xf16, 20x5x1> -> <2x3x4xf16, 12x4x1> + return %0 : !migraphx.shaped<2x3x4xf16, 12x4x1> +} diff --git a/mlir/test/Dialect/MIGraphX/ops.mlir b/mlir/test/Dialect/MIGraphX/ops.mlir index 047ad2b4caf2..f22240c07f9d 100644 --- a/mlir/test/Dialect/MIGraphX/ops.mlir +++ b/mlir/test/Dialect/MIGraphX/ops.mlir @@ -24,3 +24,25 @@ func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, -> !migraphx.shaped<1x16x16xf32, 256x16x1> return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> } + +// Checking to see if the verifier allows for broadcast +// CHECK-LABEL: func.func @migraphx_dot_no_batch_b +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_no_batch_b(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<2x2xf16, 2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <2x2xf16, 2x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank3 +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_leading_ones_b_rank3(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x2x2xf16, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x2x2xf16, 4x2x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank4 +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x1x2x2xf16, 4x2x1x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +}