Paged Attention: migraphx & highlevel pipeline changes#2220
Open
justinrosner wants to merge 5 commits intodevelopfrom
Open
Paged Attention: migraphx & highlevel pipeline changes#2220justinrosner wants to merge 5 commits intodevelopfrom
migraphx & highlevel pipeline changes#2220justinrosner wants to merge 5 commits intodevelopfrom
Conversation
migraphx & highlevel pipeline changesmigraphx & highlevel pipeline changes
Contributor
There was a problem hiding this comment.
Pull request overview
This pull request introduces paged attention support for the migraphx and highlevel pipelines. The implementation adds a new deref operation to handle pointer dereferences for paged memory access, and extends the attention operations to support optional key/value addresses for paged attention.
Changes:
- Introduces
migraphx.derefandrock.derefoperations for lazy pointer dereferencing in paged attention - Extends
rock.attentionandrock.gridwise_attention_accelwith optionalkeyAddressesandvalueAddressesoperands - Adds conversion patterns from MIGraphX → TOSA → Rock for the deref operation and paged attention support
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td | Defines migraphx.deref op with UI64 input constraint |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Defines rock.deref op and extends attention ops with address operands |
| mlir/include/mlir/Dialect/Rock/IR/RockTosaCustomOps.h | Adds ROCK_CUSTOMOP_DEREF constant |
| mlir/include/mlir/Conversion/TosaToRock/TosaToRock.h | Declares populateTosaToRockDerefPatterns function |
| mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp | Implements verification for migraphx.deref (shape/stride matching) |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Implements verification for rock.deref (rank-3 constraint) and attention ops (paged attention validation) |
| mlir/lib/Conversion/MIGraphXToTosa/MIGraphXToTosa.cpp | Converts migraphx.deref to tosa.custom deref op |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Converts tosa.custom deref to rock.deref with pattern matching; detects paged K/V in attention |
| mlir/lib/Conversion/TosaToRock/TosaToRockPass.cpp | Adds deref conversion stage before attention patterns |
| mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp | Implements bufferization for rock.deref |
| mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp | Threads keyAddresses/valueAddresses through AttentionOp lowering |
| mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp | Threads keyAddresses/valueAddresses through AttentionOp rewrite |
| mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp | Threads keyAddresses/valueAddresses through flash decoding pattern |
| mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Updates AttentionOp::create call with nullptr for new operands |
| mlir/test/Dialect/Rock/ops_error.mlir | Updates operandSegmentSizes for two test cases |
| mlir/test/Dialect/MIGraphX/ops.mlir | Adds test for migraphx.deref |
| mlir/test/Dialect/MIGraphX/invalid.mlir | Adds negative tests for migraphx.deref |
| mlir/test/Conversion/MIGraphXToTosa/mixr-to-tosa-ops.mlir | Tests migraphx.deref to tosa.custom conversion |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-paged-attention.mlir | Tests deref and paged attention conversions |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
mlir/test/Conversion/TosaToRock/tosa-to-rock-paged-attention.mlir
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Rock/Transforms/BufferizableOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
This PR introduces
migraphxandhighlevelpipeline support for paged attention. This implements: https://amd-hub.atlassian.net/browse/AIROCMLIR-46Technical Details
This PR implements the following changes:
derefop for paged attentionmigraphx.derefthat gets lowered totosa.customderef op (MIGraphXToTosa)tosa.customderef will get lowered torock.deref(TosaToRock)rock.attentionandrock.gridwise_attention_accelwithkeyAddressesandvalueAddresses.DerefOpInterfaceforrock.derefto support bufferization (conversion to memrefs).Note,
rock.derefacts as a dererred/lazy load descriptor rather than an immediate load operation. At a high-level it doesn't actually load anything, it declares:The actual memory loads happen much later during tiled/blockwise lowering. Going with this approach allows for us to use the existing pipeline of applying
rock.transformsto the K/V input to attention ops in places like SortDimensionsMemoryLayout, etc.Test Plan
Test Result
Submission Checklist