[AIROCMLIR-446] Lower migraphx.dot into linalg.matmul/batch_matmul #2223
Conversation
| // CHECK-SAME: tensor<1x3x2xf32>, tensor<1x2x3xf32>) outs(%[[cst]] : tensor<1x3x3xf32>) -> tensor<1x3x3xf32>
| // CHECK-NEXT: %[[collapsed:.*]] = tensor.collapse_shape %[[zero]]
| // CHECK-NEXT: return %[[collapsed]] : tensor<9xf32>
| func.func @dot_one(%arg0 : !migraphx.shaped<1x3x2xf32, 6x2x1>, %arg1: !migraphx.shaped<1x2x3xf32, 6x3x1>)
nit: These tests are not formatted properly. For instance, this first line is broken into 2 lines. Can you run rocmlir-opt on this file? It will give you the tests in the proper format.
I'm not sure what you mean by dot_one, dot_two, dot_three? It would make sense if they were 1D, 2D, and 3D, but that's not the case. Can you rename them to dot_3D, dot_2D, and dot_4D (the one that fails)?
I have renamed them to dot_3D, dot_2D, and dot_4D.
dot_one, dot_two, and dot_three originally meant the first, second, and third dot tests.
| @@ -0,0 +1,49 @@
| // RUN: rocmlir-opt -split-input-file --migraphx-to-linalg -verify-diagnostics %s | FileCheck %s
I had a look at the IR that this generates and it looks great, good job!
| func.func @dot_three(%arg0 : !migraphx.shaped<1x1x3x2xf32, 6x6x2x1>, %arg1: !migraphx.shaped<1x1x2x3xf32, 6x6x3x1>)
| -> !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>{
| // expected-error @+2 {{only support 2D/3D for now}}
This is fine for now; in the future we can add support for migraphx.dot in any dimension with linalg.generic.
| populateCallOpTypeConversionPattern(patterns, typeConverter);
| }
|
| void mlir::migraphx::populateMIGrpahXToLinalgTrivialConverter(
nit: populateMIGraphXToLinalgTrivialConverter
| Location loc = op.getLoc();
| Value aIn = adaptor.getInA();
| Value bIn = adaptor.getInB();
| RankedTensorType aType = cast<TypedValue<RankedTensorType>>(aIn).getType();
Does migraphx::DotOp accept unranked tensors as input? I think it does. If this is the case, this code will crash.
My understanding is that the migraphx-to-tosa lowering doesn't support unranked tensors as input.
root@f83ce68a1182$ cat test.mlir
func.func @dot_one(%arg0 : !migraphx.shaped<?x?x?xf32, ?x?x?>, %arg1: !migraphx.shaped<?x?x?xf32, ?x?x?>) -> !migraphx.shaped<?x?x?xf32, ?x?x?> {
%0 = migraphx.dot %arg0, %arg1 : <?x?x?xf32, ?x?x?>, <?x?x?xf32, ?x?x?> -> <?x?x?xf32, ?x?x?>
func.return %0 : !migraphx.shaped<?x?x?xf32, ?x?x?>
}
root@f83ce68a1182$ ./bin/rocmlir-opt test.mlir --migraphx-to-tosa
<unknown>:0: error: !migraphx.shaped type with smallest stride -9223372036854775808 has no supported in-memory layout
test.mlir:1:1: error: failed to legalize operation 'func.func' that was explicitly marked illegal
func.func @dot_one(%arg0 : !migraphx.shaped<?x?x?xf32, ?x?x?>, %arg1: !migraphx.shaped<?x?x?xf32, ?x?x?>) -> !migraphx.shaped<?x?x?xf32, ?x?x?> {
^
test.mlir:1:1: note: see current operation:
"func.func"() <{function_type = (!migraphx.shaped<?x?x?xf32, ?x?x?>, !migraphx.shaped<?x?x?xf32, ?x?x?>) -> !migraphx.shaped<?x?x?xf32, ?x?x?>, sym_name = "dot_one"}> ({
^bb0(%arg0: !migraphx.shaped<?x?x?xf32, ?x?x?>, %arg1: !migraphx.shaped<?x?x?xf32, ?x?x?>):
%0 = "migraphx.mlir.as.logical.shape"(%arg1) : (!migraphx.shaped<?x?x?xf32, ?x?x?>) -> tensor<?x?x?xf32>
%1 = "migraphx.mlir.as.logical.shape"(%arg0) : (!migraphx.shaped<?x?x?xf32, ?x?x?>) -> tensor<?x?x?xf32>
%2 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%3 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%4 = "tosa.matmul"(%1, %0, %2, %3) {acc_type = f32} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?xf32>
%5 = "migraphx.mlir.as.underlying.shape"(%4) : (tensor<?x?x?xf32>) -> !migraphx.shaped<?x?x?xf32, ?x?x?>
"func.return"(%5) : (!migraphx.shaped<?x?x?xf32, ?x?x?>) -> ()
}) : () -> ()
Is this a bug?
I have also added a test case for unranked tensor types not being supported.
I can't seem to get an unranked tensor type to work as an input for migraphx.dot; is there a different syntax?
func.func @dot_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) attributes {kernel, arch="gfx950"}{
%test = migraphx.dot %arg0, %arg1: <*xf32>, <*xf32> -> <*xf32>
func.return
}
OK, it actually does not support unranked tensors; it seems everything must be static. Look at this test:
func.func @mlir_reshape_dynamic_shape(%arg0: !migraphx.shaped<4096x?xf16, 0x1>) {
// expected-error@+1 {{'migraphx.reshape' op Dynamic shapes are not supported}}
%0 = migraphx.reshape %arg0 {dims = [4096, 4096]} : <4096x?xf16, 0x1> -> <4096x?xf16, 16536x2>
return
}
so we are good.
However, defensive programming does no harm. Can you add a check that the type is RankedTensorType and, if it's not, emit an error to avoid crashes, like the "only static shape is supported for now" error below?
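A minimal sketch of the requested guard (hedged; it reuses the aIn/bIn names from the diff quoted nearby, not the exact PR code):
// Sketch only: bail out with a diagnostic instead of crashing on non-ranked types;
// dyn_cast avoids the assertion that a bare cast<> would trigger.
auto aTensorType = dyn_cast<RankedTensorType>(aIn.getType());
auto bTensorType = dyn_cast<RankedTensorType>(bIn.getType());
if (!aTensorType || !bTensorType)
  return op.emitError("expected both operands to be RankedTensorType");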
| RankedTensorType bType = cast<TypedValue<RankedTensorType>>(bIn).getType();
| ArrayRef<int64_t> aShape = aType.getShape();
| ArrayRef<int64_t> bShape = bType.getShape();
| int64_t dimension = aShape.size();
I prefer rank for this variable
Let's add more defensive programming here: add a condition that ensures the shapes of A and B have the same rank.
| }
|
| void MIGraphXToLinalgPass::runOnOperation() {
| // MIGraphX to Linalg conversion is performed in two passes:
Adding this to the PR description would be useful
I would also add this to the description of what this pass does at the top of the file.
| void mlir::linalg::populateMIGraphXToLinalgDialectConversion(
| ConversionTarget &target) {
| target.addLegalDialect<linalg::LinalgDialect, arith::ArithDialect,
Here we will be converting migraphx.dot to its linalg equivalent, so I think we should add migraphx::DotOp as an IllegalOp.
I have added migraphx::MIGraphXDialect as an illegal dialect, except for migraphx::AsLogicalShapeOp and migraphx::AsUnderlyingShapeOp.
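A rough sketch of that legality setup (assumed shape, not the exact patch): the whole MIGraphX dialect is marked illegal while the two materialization ops stay legal.
// Sketch: dialect-level illegality with op-level exceptions for the boundary ops.
target.addIllegalDialect<migraphx::MIGraphXDialect>();
target.addLegalOp<migraphx::AsLogicalShapeOp, migraphx::AsUnderlyingShapeOp>();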
dhernandez0 left a comment
Can you change the title? We only convert migraphx.dot, not all migraphx ops, right?
| }
|
| //===----------------------------------------------------------------------===//
| // Base kernels (convolution, gemm)
this only converts gemms, not convolutions?
Yes, that is correct for now.
migraphx.dot into linalg.matmul/batch_matmul
| int64_t rank = aShape.size();
| if (aShape.size() != bShape.size()) {
| return op.emitError("input a and b must have the same rank");
| }
I think some of these checks should be in the migraphx.dot verifier if possible (if they aren't already there).
It doesn't seem like there is a migraphx.dot verifier? I can't find DotOp::verify in the codebase.
| rewriter.getZeroAttr(RankedTensorType::get(
| outputShape, aType.getElementType())));
| auto matMulOp =
| linalg::MatmulOp::create(rewriter, loc, {aIn, bIn}, zero, {});
are the layouts of a and b exactly the same?
I mean the migraphx and linalg expected layouts (MxK or KxM, etc.).
also what about "acc_type", is there something like that in linalg.matmul?
are the layouts of a and b exactly the same?
The layouts of a and b are not the same. Usually, they are [batch, M, K] and [batch, K, N].
I don't think there is an acc_type-like equivalent in linalg.batch_matmul?
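For context (my reading of linalg, not something stated in the PR): linalg named ops have no separate acc_type attribute; the accumulation type is implied by the element type of the outs init tensor. So f32 accumulation over f16 inputs would be expressed roughly as:
// Sketch: the init tensor's element type (f32 here) determines the accumulation type.
auto accTy = RankedTensorType::get(outputShape, rewriter.getF32Type());
Value accInit = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(accTy));
// accInit is then passed as the outs operand of linalg.batch_matmul.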
| }
|
| // don't emit linalg.generic for the 2D and 3D cases to preserve type sugar
| if (rank == 2) {
see MigraphxToTosa for handling of special cases such as broadcast etc
| outputShape, aType.getElementType())));
| auto matMulOp =
| linalg::MatmulOp::create(rewriter, loc, {aIn, bIn}, zero, {});
| rewriter.replaceOp(op, matMulOp);
we should make sure to copy the perf_config attribute. See MigraphxToTosa
let's add a TODO somewhere for scaled gemms.
migraphx::DeQuantizeLinearConverter
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 17 comments.
| Value zero =
| arith::ConstantOp::create(rewriter, loc,
| rewriter.getZeroAttr(RankedTensorType::get(
| outputShape, aType.getElementType())));
zero's type shouldn't be derived from aType; it should take whatever the output type of the original migraphx.dot is.
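A minimal sketch of the requested fix (hedged; origOutputTy / newOutElementTy are the names used in the later revision quoted further down):
// Sketch: derive the init tensor's element type from the dot op's result type,
// not from operand A.
auto initTy = RankedTensorType::get(outputShape, newOutElementTy);
Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(initTy));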
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 6 comments.
The logic for DotConverter now looks similar to the one in MIGraphXToTosa. The important thing to note is that this PR now somewhat supports "broadcast": it now compiles on both the tosa path and the linalg path. This seems to be intentional as well.
| RankedTensorType newAType = RankedTensorType::get(newDimsA, elementTy);
| RankedTensorType newBType = RankedTensorType::get(newDimsB, elementTy);
| newOutType = RankedTensorType::get(newDimsOut, newOutElementTy);
| inA = (rankA == 2)
clang-format seems to be doing something weird here.
| .getResult(0);
|
| // Convert optional attributes
| if (auto attr = (*op).template getAttrOfType<StringAttr>("perf_config"))
Add a test with perf_config and make sure it is passed on when converting to linalg.
| !isa<RankedTensorType>(inB.getType())) {
| return op.emitError("expected both operands to be RankedTensorType");
| }
| Type elementTy = cast<RankedTensorType>(inA.getType()).getElementType();
nit: Create a variable RankedTensorType aRankedType = cast<RankedTensorType>(inA.getType()) and use it everywhere.
| Type outElementTy = origOutputTy.getElementType();
| Type newOutElementTy = getTypeConverter()->convertType(outElementTy);
|
| // check batch dimension. Tosa matmul only allow a single dimension for it,
nit: We shouldn't mention tosa here. I get that it is to keep it aligned with the tosa implementation, but better not to mention it here.
| if (batchSizeA != batchSizeB || batchSizeC != batchSizeB) {
| return op.emitError("cannot handle this broadcast for now");
Just checking batchSizeA == batchSizeB could lead to a mismatch. E.g., I see you have a test case with {3, 2, 2, 2} and {2, 3, 2, 2} shapes, where batchSize = 6 but the dimensions do not really match. I think it wouldn't generate correct results. Can you verify?
I think it gives correct results?
func.func @dot_broadcast(%arg0: !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<6x2x2xf32, 4x2x1>) -> !migraphx.shaped<3x2x2x2xf32, 8x4x2x1> attributes {kernel, arch="gfx950"} {
%0 = migraphx.dot %arg0, %arg1 : <3x2x2x2xf32, 8x4x2x1>, <6x2x2xf32, 4x2x1> -> <3x2x2x2xf32, 8x4x2x1>
func.return %0 : !migraphx.shaped<3x2x2x2xf32, 8x4x2x1>
}
The tosa and linalg lowering path seems to give similar results:
// ./bin/rocmlir-opt main.mlir --migraphx-to-linalg --canonicalize
module {
func.func @dot_broadcast(%arg0: tensor<24xf32>, %arg1: tensor<24xf32>) -> tensor<24xf32> attributes {arch = "gfx950", kernel} {
%cst = arith.constant dense<0.000000e+00> : tensor<6x2x2xf32>
%expanded = tensor.expand_shape %arg1 [[0, 1, 2]] output_shape [6, 2, 2] : tensor<24xf32> into tensor<6x2x2xf32>
%expanded_0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [6, 2, 2] : tensor<24xf32> into tensor<6x2x2xf32>
%0 = linalg.batch_matmul ins(%expanded_0, %expanded : tensor<6x2x2xf32>, tensor<6x2x2xf32>) outs(%cst : tensor<6x2x2xf32>) -> tensor<6x2x2xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<6x2x2xf32> into tensor<24xf32>
return %collapsed : tensor<24xf32>
}
}
// ./bin/rocmlir-opt main.mlir --migraphx-to-tosa --canonicalize
module {
func.func @dot_broadcast(%arg0: tensor<24xf32>, %arg1: tensor<24xf32>) -> tensor<24xf32> attributes {arch = "gfx950", kernel} {
%0 = tosa.const_shape {values = dense<24> : tensor<1xindex>} : () -> !tosa.shape<1>
%1 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%2 = tosa.const_shape {values = dense<[6, 2, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
%3 = tosa.reshape %arg1, %2 : (tensor<24xf32>, !tosa.shape<3>) -> tensor<6x2x2xf32>
%4 = tosa.reshape %arg0, %2 : (tensor<24xf32>, !tosa.shape<3>) -> tensor<6x2x2xf32>
%5 = tosa.matmul %4, %3, %1, %1 {acc_type = f32} : (tensor<6x2x2xf32>, tensor<6x2x2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<6x2x2xf32>
%6 = tosa.reshape %5, %0 : (tensor<6x2x2xf32>, !tosa.shape<1>) -> tensor<24xf32>
return %6 : tensor<24xf32>
}
}
| if (rankA != rankB || rankB != outRank) { |
We can support rankA != rankB in some cases, e.g. A = {batch, m, k} and B = {k, n}: A can be folded to {batch * m, k}.
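A hypothetical sketch of that folding (names assumed, not from the patch):
// Sketch: A = {batch, m, k} collapses to {batch*m, k}; a 2-D matmul with B = {k, n}
// then yields {batch*m, n}, which tensor.expand_shape restores to {batch, m, n}.
SmallVector<ReassociationIndices> foldMap = {{0, 1}, {2}};
Value aFolded = tensor::CollapseShapeOp::create(rewriter, loc, aIn, foldMap);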
umangyadav left a comment
Can you create a ticket to handle cases where rankA != rankB and work on it later?
Motivation
Lower MIGraphX GEMM into Linalg Dialect.
Technical Details
Initial changes to lower from MIGraphX to Linalg.
MIGraphX to Linalg conversion is performed in two passes:
Pass 1: Convert MIGraphX operations to their Linalg equivalents. The !migraphx.shaped type contains both shape and stride (memory layout) information. During this pass, sourceMaterialization and targetMaterialization insert temporary ops (migraphx.mlir.as_logical_shape and migraphx.mlir.as_underlying_shape) to handle the type conversions.
Pass 2: Convert the boundary/materialization operations to proper memory layout representations using tensor operations, completing the conversion from MIGraphX's shaped types to standard tensor types.
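A rough sketch (assumed wiring, not taken from the patch; exact hook signatures depend on the MLIR version) of how the materializations described above could be registered on the pass's TypeConverter:
// Sketch: boundary ops are inserted by the type converter's materialization hooks.
typeConverter.addTargetMaterialization(
    [](OpBuilder &b, Type resultType, ValueRange inputs, Location loc) -> Value {
      // !migraphx.shaped -> tensor at converted-op boundaries.
      return migraphx::AsLogicalShapeOp::create(b, loc, resultType, inputs[0]);
    });
typeConverter.addSourceMaterialization(
    [](OpBuilder &b, Type resultType, ValueRange inputs, Location loc) -> Value {
      // tensor -> !migraphx.shaped for values flowing back to unconverted uses.
      return migraphx::AsUnderlyingShapeOp::create(b, loc, resultType, inputs[0]);
    });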
Test Plan
Lit test for diagnostics and IR output.
Test Result
One lit test for now; it passes locally.
Submission Checklist