From d0982bbcd89662ca9e8fb582e12412d60e35742c Mon Sep 17 00:00:00 2001 From: Vincent Date: Wed, 11 Feb 2026 22:21:39 +0000 Subject: [PATCH 1/4] Added verifier for `migraphx.dot` --- .../mlir/Dialect/MIGraphX/IR/MIGraphX.td | 1 + mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp | 49 ++++++++++++++++++- mlir/test/Dialect/MIGraphX/invalid.mlir | 36 ++++++++++++++ mlir/test/Dialect/MIGraphX/ops.mlir | 22 +++++++++ 4 files changed, 107 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td index 55de986b6721..d8f97698af49 100644 --- a/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td +++ b/mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td @@ -564,6 +564,7 @@ def MIGraphX_DotOp : MIGraphX_Op<"dot">, let assemblyFormat = [{ $in_a `,` $in_b attr-dict `:` type($in_a) `,` type($in_b) `->` type($output) }]; + let hasVerifier = 1; } def MIGraphX_SoftmaxOp : diff --git a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp index 605881e9c691..097f1a4c5da2 100644 --- a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp +++ b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp @@ -384,6 +384,46 @@ LogicalResult UnpackOp::verify() { return success(); } +static LogicalResult isValidDotOp(Operation *op, MIXRShapedType inAType, + MIXRShapedType inBType, + MIXRShapedType outputType) { + ArrayRef shapeA = inAType.getShape(); + ArrayRef shapeB = inBType.getShape(); + int64_t outputRank = outputType.getRank(); + + if (!llvm::all_of( + ArrayRef{inAType.getRank(), inBType.getRank(), outputRank}, + [](int64_t rank) { return rank >= 2; })) { + return op->emitOpError("expect operand to have rank greater or equal to 2"); + } + + // Batch dimensions (all dims except the last two) must be compatible. + // Broadcasting is allowed when one operand's batch dims are all ones + // or when one operand has no batch dims (rank 2). For example: + // A = {3, 2, 2, 2} and B = {1, 1, 2, 2} (batch B is all ones) - valid + // A = {3, 2, 2, 2} and B = {2, 2} (B has no batch dims) - valid + // A = {3, 2, 2, 2} and B = {2, 3, 2, 2} (batch dims differ) - invalid + ArrayRef batchA = shapeA.drop_back(2); + ArrayRef batchB = shapeB.drop_back(2); + bool hasLeadingOnesB = llvm::all_of(batchB, [](int64_t d) { return d == 1; }); + if (!hasLeadingOnesB && + !std::equal(batchA.begin(), batchA.end(), batchB.begin(), batchB.end())) { + return op->emitOpError("batch dimension mismatch: the first operand (") + << inAType << ") and the second operand (" << inBType + << ") have incompatible batch dimensions"; + } + + int64_t lastAShape = shapeA[shapeA.size() - 1]; + int64_t secondLastBShape = shapeB[shapeB.size() - 2]; + if (lastAShape != secondLastBShape) { + return op->emitOpError("the first operand (") + << inAType << ") and the second operand(" << inBType + << "are incompatible"; + } + + return success(); +} + LogicalResult QuantDotOp::verify() { MIXRShapedType inAType = getInA().getType(); MIXRShapedType inBType = getInB().getType(); @@ -431,5 +471,12 @@ LogicalResult QuantDotOp::verify() { return emitOpError("Quant Dot ops requires scales to be provided to use " "f4E2M1FN element type"); } - return success(); + return isValidDotOp(getOperation(), inAType, inBType, getType()); +} + +LogicalResult DotOp::verify() { + MIXRShapedType inAType = getInA().getType(); + MIXRShapedType inBType = getInB().getType(); + + return isValidDotOp(getOperation(), inAType, inBType, getType()); } diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index 3e5a94c0e3c7..57f3028c19d1 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -225,3 +225,39 @@ func.func @migraphx_quant_dot_f4_n_scales(%arg0: !migraphx.shaped<1x16x512xf4E2M -> <1x16x16xf32, 256x16x1> return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> } + +// ----- + +// CHECK-LABEL: func.func @dot_rank_less_than_2 +func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !migraphx.shaped<320x64xf16, 64x1>) -> !migraphx.shaped<64xf16, 1> { + // expected-error @+1 {{expect operand to have rank greater or equal to 2}} + %0 = migraphx.dot %arg0, %arg1 : <320xf16, 1>, <320x64xf16, 64x1> -> <64xf16, 1> + return %0 : !migraphx.shaped<64xf16, 1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_incompatible_inner_dim +func.func @dot_incompatible_inner_dim(%arg0: !migraphx.shaped<2x64x320xf16, 20480x320x1>, %arg1: !migraphx.shaped<2x256x64xf16, 16384x64x1>) -> !migraphx.shaped<2x64x64xf16, 4096x64x1> { + // expected-error @+1 {{the first operand ('!migraphx.shaped<2x64x320xf16, 20480x320x1>') and the second operand('!migraphx.shaped<2x256x64xf16, 16384x64x1>'are incompatible}} + %0 = migraphx.dot %arg0, %arg1 : <2x64x320xf16, 20480x320x1>, <2x256x64xf16, 16384x64x1> -> <2x64x64xf16, 4096x64x1> + return %0 : !migraphx.shaped<2x64x64xf16, 4096x64x1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_invalid_batch +func.func @dot_invalid_batch(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { + // expected-error@+1 {{batch dimension mismatch: the first operand ('!migraphx.shaped<3x2x2x2xf32, 8x4x2x1>') and the second operand ('!migraphx.shaped<6x2x2xf32, 4x2x1>') have incompatible batch dimensions}} + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1> + func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> +} + +// ----- + +// CHECK-LABEL: func.func @dot_invalid_batch +func.func @dot_invalid_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { + // expected-error@+1 {{batch dimension mismatch: the first operand}} + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1> + func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> +} diff --git a/mlir/test/Dialect/MIGraphX/ops.mlir b/mlir/test/Dialect/MIGraphX/ops.mlir index 047ad2b4caf2..f22240c07f9d 100644 --- a/mlir/test/Dialect/MIGraphX/ops.mlir +++ b/mlir/test/Dialect/MIGraphX/ops.mlir @@ -24,3 +24,25 @@ func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN, -> !migraphx.shaped<1x16x16xf32, 256x16x1> return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1> } + +// Checking to see if the verifier allows for broadcast +// CHECK-LABEL: func.func @migraphx_dot_no_batch_b +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_no_batch_b(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<2x2xf16, 2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <2x2xf16, 2x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank3 +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_leading_ones_b_rank3(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x2x2xf16, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x2x2xf16, 4x2x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} + +// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank4 +// CHECK-NEXT: migraphx.dot +func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x1x2x2xf16, 4x2x1x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> { + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1> + return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> +} From 0586d0a24b81b0b004bf1d6f0e8dcf6a0c89a85b Mon Sep 17 00:00:00 2001 From: Vincent Date: Wed, 11 Feb 2026 22:54:40 +0000 Subject: [PATCH 2/4] update testcase --- mlir/test/Dialect/MIGraphX/invalid.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index 57f3028c19d1..00e20f5fa932 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -256,8 +256,8 @@ func.func @dot_invalid_batch(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg // ----- // CHECK-LABEL: func.func @dot_invalid_batch -func.func @dot_invalid_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { +func.func @dot_invalid_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x3x2x2xf32, 12x4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { // expected-error@+1 {{batch dimension mismatch: the first operand}} - %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1> + %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <2x3x2x2xf32, 12x4x2x1> -> <3x2x2x2xf32, 8x4x2x1> func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> } From 38d85b4f7c6ece9dc73df29534cf743ce4bfe25d Mon Sep 17 00:00:00 2001 From: Vincent Date: Thu, 12 Feb 2026 14:52:04 +0000 Subject: [PATCH 3/4] Address comments --- mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp | 16 +++++++++++++--- mlir/test/Dialect/MIGraphX/invalid.mlir | 13 +++++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp index 097f1a4c5da2..30707536a049 100644 --- a/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp +++ b/mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp @@ -389,6 +389,7 @@ static LogicalResult isValidDotOp(Operation *op, MIXRShapedType inAType, MIXRShapedType outputType) { ArrayRef shapeA = inAType.getShape(); ArrayRef shapeB = inBType.getShape(); + ArrayRef shapeOut = outputType.getShape(); int64_t outputRank = outputType.getRank(); if (!llvm::all_of( @@ -416,9 +417,18 @@ static LogicalResult isValidDotOp(Operation *op, MIXRShapedType inAType, int64_t lastAShape = shapeA[shapeA.size() - 1]; int64_t secondLastBShape = shapeB[shapeB.size() - 2]; if (lastAShape != secondLastBShape) { - return op->emitOpError("the first operand (") - << inAType << ") and the second operand(" << inBType - << "are incompatible"; + return op->emitOpError( + "contraction dimension mismatch: the first operand (") + << inAType << ") and the second operand (" << inBType + << ") have incompatible contraction dimensions"; + } + + // checking the output dimension, which must match the input + if (!std::equal(shapeA.rbegin() + 2, shapeA.rend(), shapeOut.rbegin() + 2, + shapeOut.rend()) || + *std::prev(shapeOut.end()) != *std::prev(shapeB.end()) || + *std::prev(shapeOut.end(), 2) != *std::prev(shapeA.end(), 2)) { + return op->emitOpError("result type is inconsistent with input shapes"); } return success(); diff --git a/mlir/test/Dialect/MIGraphX/invalid.mlir b/mlir/test/Dialect/MIGraphX/invalid.mlir index 00e20f5fa932..32c92195fd9b 100644 --- a/mlir/test/Dialect/MIGraphX/invalid.mlir +++ b/mlir/test/Dialect/MIGraphX/invalid.mlir @@ -239,7 +239,7 @@ func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !mig // CHECK-LABEL: func.func @dot_incompatible_inner_dim func.func @dot_incompatible_inner_dim(%arg0: !migraphx.shaped<2x64x320xf16, 20480x320x1>, %arg1: !migraphx.shaped<2x256x64xf16, 16384x64x1>) -> !migraphx.shaped<2x64x64xf16, 4096x64x1> { - // expected-error @+1 {{the first operand ('!migraphx.shaped<2x64x320xf16, 20480x320x1>') and the second operand('!migraphx.shaped<2x256x64xf16, 16384x64x1>'are incompatible}} + // expected-error @+1 {{contraction dimension mismatch: the first operand}} %0 = migraphx.dot %arg0, %arg1 : <2x64x320xf16, 20480x320x1>, <2x256x64xf16, 16384x64x1> -> <2x64x64xf16, 4096x64x1> return %0 : !migraphx.shaped<2x64x64xf16, 4096x64x1> } @@ -255,9 +255,18 @@ func.func @dot_invalid_batch(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg // ----- -// CHECK-LABEL: func.func @dot_invalid_batch +// CHECK-LABEL: func.func @dot_invalid_broadcast func.func @dot_invalid_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x3x2x2xf32, 12x4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { // expected-error@+1 {{batch dimension mismatch: the first operand}} %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <2x3x2x2xf32, 12x4x2x1> -> <3x2x2x2xf32, 8x4x2x1> func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> } + +// ----- + +// CHECK-LABEL: func.func @dot_result_shape_mismatch +func.func @dot_result_shape_mismatch(%arg0: !migraphx.shaped<2x3x4xf16, 12x4x1>, %arg1: !migraphx.shaped<2x4x5xf16, 20x5x1>) -> !migraphx.shaped<2x3x4xf16, 12x4x1> { + // expected-error @+1 {{result type is inconsistent with input shapes}} + %0 = migraphx.dot %arg0, %arg1 : <2x3x4xf16, 12x4x1>, <2x4x5xf16, 20x5x1> -> <2x3x4xf16, 12x4x1> + return %0 : !migraphx.shaped<2x3x4xf16, 12x4x1> +} From bd72b42dc6f1df8be1af1fc70eb3fb2da341734e Mon Sep 17 00:00:00 2001 From: Vincent Date: Thu, 12 Feb 2026 17:17:48 +0000 Subject: [PATCH 4/4] Remove duplicate testcase --- .../MIGraphXToLinalg/migraphx-to-linalg-dot.mlir | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir b/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir index 4e9ab03592da..1dec63931fd2 100644 --- a/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir +++ b/mlir/test/Conversion/MIGraphXToLinalg/migraphx-to-linalg-dot.mlir @@ -72,15 +72,6 @@ func.func @dot_f16(%arg0: !migraphx.shaped<8x64x64x320xf16, 1310720x20480x320x1> // ----- -func.func @dot_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} { - // expected-error @+2 {{operands must have the same rank}} - // expected-error @+1 {{failed to legalize operation 'migraphx.dot' that was explicitly marked illegal}} - %0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1> - func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> -} - -// ----- - func.func @dot_unranked_tensor(%arg0 : !migraphx.shaped, %arg1: !migraphx.shaped) -> !migraphx.shaped { // expected-error @+2 {{only static shape is supported for now}} // expected-error @+1 {{failed to legalize operation 'migraphx.dot' that was explicitly marked illegal}}