-
Notifications
You must be signed in to change notification settings - Fork 25
[water] Replace DictionaryAttr with WaveIndexExprsAttr for ordered index expressions #730
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
10a3b95 to
9406810
Compare
Add new attributes to represent ordered index expressions that preserve dimension order, unlike DictionaryAttr which sorts entries alphabetically. This is the first step toward fixing the index expression ordering issue where dimension order is lost during lattice operations because: 1. DenseMap loses insertion order during join operations 2. DictionaryAttr::get() always sorts entries alphabetically New attributes: - WaveIndexEntryAttr: A single (dimension, mapping) pair - WaveIndexExprsAttr: An ordered array of entries The order corresponds to the tensor type's shape dimension order, which is critical for lowering when the tensor type has been converted to memref and is no longer available. Syntax: index_exprs<[@m : <mapping>, @k : <mapping>, @n : <mapping>]> Signed-off-by: tyb0807 <sontuan.vu@amd.com>
d51c2b5 to
746aef2
Compare
|
|
||
| def WaveIndexExprsArrayAttr : TypedArrayAttrBase<WaveIndexExprsAttr, | ||
| "array of WaveIndexExprsAttr"> { | ||
| let constBuilderCall = "$_builder.getArrayAttr($0)"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we leave the constBuilderCall out and just use the default builder implementation?
68c8cfe to
9a4f30f
Compare
…xpressions This commit uses WaveIndexExprsAttr and WaveIndexEntryAttr to replace DictionaryAttr for storing index expressions. The key motivation is that DictionaryAttr alphabetically sorts its entries, but dimension order must be preserved to match the tensor type's shape order for correct lowering. Changes: - Custom parsing/printing that preserves order - Update IndexExprsLatticeStorage to use WaveIndexExprsAttr - Update join() logic to use MapVector for order preservation - Add C API and Python bindings - Update Python emitter/converter The join() logic uses LHS-first ordering policy: LHS entries come first (in LHS order), then new RHS entries (in RHS order). This ensures deterministic output when joining lattices. Signed-off-by: tyb0807 <sontuan.vu@amd.com>
ftynse
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First batch
| // Index expression attributes (ordered) | ||
| //----------------------------------------------------------------------------- | ||
|
|
||
| def WaveIndexEntryAttr : AttrDef<WaveDialect, "WaveIndexEntry"> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the point of having entry as an attribute? Do we ever need a single entry? Otherwise, we are just increasing the cost of manipulating these objects without benefit: attributes are created under lock and accessing each attribute adds a pointer indirection to get to the context-owned memory. Unless we use or intend to use index entries separately, we just store an array of pairs, pretty much like DictAttr does.
| /// Look up the index mapping for a given dimension symbol. | ||
| /// Returns std::nullopt if the dimension is not found. | ||
| /// Complexity: O(n) where n is the number of dimensions (typically 2-4). | ||
| std::optional<::wave::WaveIndexMappingAttr> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You never need optional around a type that is nullable. Just use null as failure. Otherwise you have two "failure" states with no clear difference.
| /// Look up by dimension name string. | ||
| std::optional<::wave::WaveIndexMappingAttr> | ||
| lookup(::llvm::StringRef dimensionName) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to avoid string-based APIs as much as possible. If we don't need it, let's remove this.
| lookup(::llvm::StringRef dimensionName) const; | ||
|
|
||
| /// Get the ordered list of dimension symbols. | ||
| ::llvm::SmallVector<::wave::WaveSymbolAttr> getDimensions() const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I think you can make the return type auto if you put implementation here, and just return llvm::map_range(...) (the actual type is atrocious), which will avoid the need to construct and potentially copy a vector.
| /// Ordering semantics: | ||
| /// - LHS entries come first (in LHS order), then RHS-only entries (in RHS | ||
| /// order). | ||
| /// - Entries with the same dimension have their mappings merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merged how exactly? This should refer to the actual implementation function that has this documentation.
| /// order. This is the common case for elementwise ops like wave.add. | ||
| /// | ||
| /// 2. **MMA ops**: LHS has {M, K}, RHS has {N, K}, accumulator/result has | ||
| /// {M, N}. The `ignoredRhsSymbols` parameter filters dimensions that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem to say anything about the order, which is the kinda the point of this documentation...
| /// 3. **Iterate ops**: Block arguments are joined with iter_args, and | ||
| /// terminator operands with results. Both should have matching tensor | ||
| /// types and thus matching dimension order. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This describes a quite fundamental verification constraint that any control-flow operation follows, why do we need it here exactly?
| ); | ||
|
|
||
| string commonArgumentsSyntax = "( `index` custom<WaveIndexDict>($index)^ )?"; | ||
| // Index is now printed as part of attr-dict using standard WaveIndexExprsAttr format. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then just remove commonArgumentsSyntax entirely.
Also, as I mentioned in some previous review, comments should explain the code as it is right now, not compare to how it used to be but is not anymore (the "now" part). This is a thoroughly unnecessary complexity forcing whoever reads the comment to remember or discover how things used to work, which makes the value of comments negative.
| if (parser.parseCustomAttributeWithFallback<WaveSymbolAttr>(dimension)) | ||
| return {}; | ||
|
|
||
| // Parse colon |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could guess this from the line immediately below. Please keep obvious comments to the minimum.
| SmallVector<DictionaryAttr> dicts; | ||
| dicts.reserve(arr.size()); | ||
| return op->emitError( | ||
| "'index' attribute must be an array of WaveIndexExprsAttr"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be one with a constraint in ODS
ftynse
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite follow the idea of LHS/RHS join of symbols. Now that index expressions are supposed to reflect the order of dimensions in some tensor, shouldn't we somehow enforce that order? It may be fine in the join itself, but we will need to reorder the entries before setting the final value after the analysis.
Relatedly, I suspect we want a verifier that checks whether the order of dimensions in the index mapping attribute matches that of some tensor. For elementwise arithmetic, this is straightforward, but one needs to think how to implement this for other operations, e.g., MMAs where none of the tensors have all dimensions? Reductions? Reshape?
Separately, how do we handle the situation when two mappings are identical except for the ordering?
| llvm::SmallVector<wave::WaveIndexEntryAttr> entries; | ||
| entries.reserve(result.size()); | ||
| for (auto &[dimension, mapping] : result) { | ||
| entries.push_back(wave::WaveIndexEntryAttr::get(ctx, dimension, mapping)); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't there a takeVector on MapVector? The name of the data structure kinda implies there's a vector inside.
| // Populate the entries with all index expressions. | ||
| void populate(llvm::SmallVectorImpl<wave::WaveIndexEntryAttr> &entries) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This renaming is a kind of a mindless churn that makes everyone's live harder for no good reason (attributes is a perfectly valid description of WaveIndexEntryAttr): rebases are atrocious, this PR appears on git blames when it is not very relevant, integration suffers. It may be easy to implement with a tool, but causes negative externalities ad should be avoided.
Now, ironically, I'm pushing to remove the EntryAttr, at which point this change will make more sense, so let's keep it, but it was important to point this out.
| Syntax: @M : [symbols] -> (start, step, stride) | ||
| Example: @M : [#wave.index_symbol<WG0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M, 1, BLOCK_M) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
THis PR literally has tests that show different syntax. Don't put examples in here, they are bitrotten before you even sent your PR!
Index expressions map tensor dimensions to thread/workgroup coordinates for memory access lowering. The dimension order in these expressions must match the tensor type's shape order (e.g.,
[@M, @K] -> {M: ..., K: ...}). However, DictionaryAttr alphabetically sorts its entries, so{M, K}becomes{K, M}. This causes incorrect lowering when the tensor type is no longer available (e.g., after conversion to memref).This PR introduces WaveIndexExprsAttr, an ordered attribute that preserves insertion order. It consists of:
Syntax:
#wave.index_exprs<[@M : <mapping>, @K : <mapping>]>The join() operation in lattice analysis uses a LHS-first policy (LHS entries first, then RHS-only entries) as a conventional choice to ensure deterministic output.
Changes: