Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def MIGraphX_DotOp : MIGraphX_Op<"dot">,
let assemblyFormat = [{
$in_a `,` $in_b attr-dict `:` type($in_a) `,` type($in_b) `->` type($output)
}];
let hasVerifier = 1;
}

def MIGraphX_SoftmaxOp :
Expand Down
59 changes: 58 additions & 1 deletion mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,56 @@ LogicalResult UnpackOp::verify() {
return success();
}

static LogicalResult isValidDotOp(Operation *op, MIXRShapedType inAType,
MIXRShapedType inBType,
MIXRShapedType outputType) {
ArrayRef<int64_t> shapeA = inAType.getShape();
ArrayRef<int64_t> shapeB = inBType.getShape();
ArrayRef<int64_t> shapeOut = outputType.getShape();
int64_t outputRank = outputType.getRank();

if (!llvm::all_of(
ArrayRef<int64_t>{inAType.getRank(), inBType.getRank(), outputRank},
[](int64_t rank) { return rank >= 2; })) {
return op->emitOpError("expect operand to have rank greater or equal to 2");
}

// Batch dimensions (all dims except the last two) must be compatible.
// Broadcasting is allowed when one operand's batch dims are all ones
// or when one operand has no batch dims (rank 2). For example:
// A = {3, 2, 2, 2} and B = {1, 1, 2, 2} (batch B is all ones) - valid
// A = {3, 2, 2, 2} and B = {2, 2} (B has no batch dims) - valid
// A = {3, 2, 2, 2} and B = {2, 3, 2, 2} (batch dims differ) - invalid
ArrayRef<int64_t> batchA = shapeA.drop_back(2);
ArrayRef<int64_t> batchB = shapeB.drop_back(2);
bool hasLeadingOnesB = llvm::all_of(batchB, [](int64_t d) { return d == 1; });
if (!hasLeadingOnesB &&
!std::equal(batchA.begin(), batchA.end(), batchB.begin(), batchB.end())) {
return op->emitOpError("batch dimension mismatch: the first operand (")
<< inAType << ") and the second operand (" << inBType
<< ") have incompatible batch dimensions";
}

int64_t lastAShape = shapeA[shapeA.size() - 1];
int64_t secondLastBShape = shapeB[shapeB.size() - 2];
if (lastAShape != secondLastBShape) {
return op->emitOpError(
"contraction dimension mismatch: the first operand (")
<< inAType << ") and the second operand (" << inBType
<< ") have incompatible contraction dimensions";
}

// checking the output dimension, which must match the input
if (!std::equal(shapeA.rbegin() + 2, shapeA.rend(), shapeOut.rbegin() + 2,
shapeOut.rend()) ||
*std::prev(shapeOut.end()) != *std::prev(shapeB.end()) ||
*std::prev(shapeOut.end(), 2) != *std::prev(shapeA.end(), 2)) {
return op->emitOpError("result type is inconsistent with input shapes");
}

return success();
}

LogicalResult QuantDotOp::verify() {
MIXRShapedType inAType = getInA().getType();
MIXRShapedType inBType = getInB().getType();
Expand Down Expand Up @@ -431,5 +481,12 @@ LogicalResult QuantDotOp::verify() {
return emitOpError("Quant Dot ops requires scales to be provided to use "
"f4E2M1FN element type");
}
return success();
return isValidDotOp(getOperation(), inAType, inBType, getType());
}

LogicalResult DotOp::verify() {
MIXRShapedType inAType = getInA().getType();
MIXRShapedType inBType = getInB().getType();

return isValidDotOp(getOperation(), inAType, inBType, getType());
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,6 @@ func.func @dot_f16(%arg0: !migraphx.shaped<8x64x64x320xf16, 1310720x20480x320x1>

// -----

func.func @dot_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} {
// expected-error @+2 {{operands must have the same rank}}
// expected-error @+1 {{failed to legalize operation 'migraphx.dot' that was explicitly marked illegal}}
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1>
func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>
}

// -----

func.func @dot_unranked_tensor(%arg0 : !migraphx.shaped<?x?x?xf32, ?x?x?>, %arg1: !migraphx.shaped<?x?x?xf32, ?x?x?>) -> !migraphx.shaped<?x?x?xf32, ?x?x?> {
// expected-error @+2 {{only static shape is supported for now}}
// expected-error @+1 {{failed to legalize operation 'migraphx.dot' that was explicitly marked illegal}}
Expand Down
45 changes: 45 additions & 0 deletions mlir/test/Dialect/MIGraphX/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,48 @@ func.func @migraphx_quant_dot_f4_n_scales(%arg0: !migraphx.shaped<1x16x512xf4E2M
-> <1x16x16xf32, 256x16x1>
return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1>
}

// -----

// CHECK-LABEL: func.func @dot_rank_less_than_2
func.func @dot_rank_less_than_2(%arg0: !migraphx.shaped<320xf16, 1>, %arg1: !migraphx.shaped<320x64xf16, 64x1>) -> !migraphx.shaped<64xf16, 1> {
// expected-error @+1 {{expect operand to have rank greater or equal to 2}}
%0 = migraphx.dot %arg0, %arg1 : <320xf16, 1>, <320x64xf16, 64x1> -> <64xf16, 1>
return %0 : !migraphx.shaped<64xf16, 1>
}

// -----

// CHECK-LABEL: func.func @dot_incompatible_inner_dim
func.func @dot_incompatible_inner_dim(%arg0: !migraphx.shaped<2x64x320xf16, 20480x320x1>, %arg1: !migraphx.shaped<2x256x64xf16, 16384x64x1>) -> !migraphx.shaped<2x64x64xf16, 4096x64x1> {
// expected-error @+1 {{contraction dimension mismatch: the first operand}}
%0 = migraphx.dot %arg0, %arg1 : <2x64x320xf16, 20480x320x1>, <2x256x64xf16, 16384x64x1> -> <2x64x64xf16, 4096x64x1>
return %0 : !migraphx.shaped<2x64x64xf16, 4096x64x1>
}

// -----

// CHECK-LABEL: func.func @dot_invalid_batch
func.func @dot_invalid_batch(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} {
// expected-error@+1 {{batch dimension mismatch: the first operand ('!migraphx.shaped<3x2x2x2xf32, 8x4x2x1>') and the second operand ('!migraphx.shaped<6x2x2xf32, 4x2x1>') have incompatible batch dimensions}}
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1>
func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>
}

// -----

// CHECK-LABEL: func.func @dot_invalid_broadcast
func.func @dot_invalid_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x3x2x2xf32, 12x4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} {
// expected-error@+1 {{batch dimension mismatch: the first operand}}
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <2x3x2x2xf32, 12x4x2x1> -> <3x2x2x2xf32, 8x4x2x1>
func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>
}

// -----

// CHECK-LABEL: func.func @dot_result_shape_mismatch
func.func @dot_result_shape_mismatch(%arg0: !migraphx.shaped<2x3x4xf16, 12x4x1>, %arg1: !migraphx.shaped<2x4x5xf16, 20x5x1>) -> !migraphx.shaped<2x3x4xf16, 12x4x1> {
// expected-error @+1 {{result type is inconsistent with input shapes}}
%0 = migraphx.dot %arg0, %arg1 : <2x3x4xf16, 12x4x1>, <2x4x5xf16, 20x5x1> -> <2x3x4xf16, 12x4x1>
return %0 : !migraphx.shaped<2x3x4xf16, 12x4x1>
}
22 changes: 22 additions & 0 deletions mlir/test/Dialect/MIGraphX/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,25 @@ func.func @migraphx_quant_dot_scaled(%arg0: !migraphx.shaped<1x16x512xf4E2M1FN,
-> !migraphx.shaped<1x16x16xf32, 256x16x1>
return %0 : !migraphx.shaped<1x16x16xf32, 256x16x1>
}

// Checking to see if the verifier allows for broadcast
// CHECK-LABEL: func.func @migraphx_dot_no_batch_b
// CHECK-NEXT: migraphx.dot
func.func @migraphx_dot_no_batch_b(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<2x2xf16, 2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> {
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <2x2xf16, 2x1> -> <3x2x2x2xf16, 8x4x2x1>
return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>
}

// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank3
// CHECK-NEXT: migraphx.dot
func.func @migraphx_dot_leading_ones_b_rank3(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x2x2xf16, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> {
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x2x2xf16, 4x2x1> -> <3x2x2x2xf16, 8x4x2x1>
return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>
}

// CHECK-LABEL: func.func @migraphx_dot_leading_ones_b_rank4
// CHECK-NEXT: migraphx.dot
func.func @migraphx_dot_leading_ones_b_rank4(%arg0: !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>, %arg1: !migraphx.shaped<1x1x2x2xf16, 4x2x1x1>) -> !migraphx.shaped<3x2x2x2xf16, 8x4x2x1> {
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf16, 8x4x2x1>, <1x1x2x2xf16, 4x2x1x1> -> <3x2x2x2xf16, 8x4x2x1>
return %0 : !migraphx.shaped<3x2x2x2xf16, 8x4x2x1>
}