Skip to content

Conversation

@tyb0807
Copy link
Contributor

@tyb0807 tyb0807 commented Jan 14, 2026

Index expressions map tensor dimensions to thread/workgroup coordinates for memory access lowering. The dimension order in these expressions must match the tensor type's shape order (e.g., [@M, @K] -> {M: ..., K: ...}). However, DictionaryAttr alphabetically sorts its entries, so {M, K} becomes {K, M}. This causes incorrect lowering when the tensor type is no longer available (e.g., after conversion to memref).

This PR introduces WaveIndexExprsAttr, an ordered attribute that preserves insertion order. It consists of:

  • WaveIndexEntryAttr: A single (dimension, mapping) pair where dimension is a WaveSymbolAttr and mapping is a WaveIndexMappingAttr
  • WaveIndexExprsAttr: An ordered array of WaveIndexEntryAttr entries

Syntax: #wave.index_exprs<[@M : <mapping>, @K : <mapping>]>

The join() operation in lattice analysis uses a LHS-first policy (LHS entries first, then RHS-only entries) as a conventional choice to ensure deterministic output.

Changes:

  • Add WaveIndexEntryAttr and WaveIndexExprsAttr attribute definitions
  • Update IndexExprsLatticeStorage to use WaveIndexExprsAttr and MapVector for order preservation
  • Add C API and Python bindings
  • Update Python emitter/converter

@tyb0807 tyb0807 requested a review from ftynse January 14, 2026 14:24
@tyb0807 tyb0807 force-pushed the idx_attr branch 3 times, most recently from 10a3b95 to 9406810 Compare January 14, 2026 14:35
Add new attributes to represent ordered index expressions that preserve
dimension order, unlike DictionaryAttr which sorts entries alphabetically.

This is the first step toward fixing the index expression ordering issue
where dimension order is lost during lattice operations because:
1. DenseMap loses insertion order during join operations
2. DictionaryAttr::get() always sorts entries alphabetically

New attributes:
- WaveIndexEntryAttr: A single (dimension, mapping) pair
- WaveIndexExprsAttr: An ordered array of entries

The order corresponds to the tensor type's shape dimension order, which
is critical for lowering when the tensor type has been converted to
memref and is no longer available.

Syntax: index_exprs<[@m : <mapping>, @k : <mapping>, @n : <mapping>]>

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
@tyb0807 tyb0807 force-pushed the idx_attr branch 2 times, most recently from d51c2b5 to 746aef2 Compare January 14, 2026 15:02

def WaveIndexExprsArrayAttr : TypedArrayAttrBase<WaveIndexExprsAttr,
"array of WaveIndexExprsAttr"> {
let constBuilderCall = "$_builder.getArrayAttr($0)";
Copy link
Contributor

@tgymnich tgymnich Jan 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave the constBuilderCall out and just use the default builder implementation?

@tyb0807 tyb0807 force-pushed the idx_attr branch 7 times, most recently from 68c8cfe to 9a4f30f Compare January 15, 2026 10:25
…xpressions

This commit uses WaveIndexExprsAttr and WaveIndexEntryAttr to replace
DictionaryAttr for storing index expressions. The key motivation is that
DictionaryAttr alphabetically sorts its entries, but dimension order must
be preserved to match the tensor type's shape order for correct lowering.

Changes:
- Custom parsing/printing that preserves order
- Update IndexExprsLatticeStorage to use WaveIndexExprsAttr
- Update join() logic to use MapVector for order preservation
- Add C API and Python bindings
- Update Python emitter/converter

The join() logic uses LHS-first ordering policy: LHS entries come first
(in LHS order), then new RHS entries (in RHS order). This ensures
deterministic output when joining lattices.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Copy link
Contributor

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First batch

// Index expression attributes (ordered)
//-----------------------------------------------------------------------------

def WaveIndexEntryAttr : AttrDef<WaveDialect, "WaveIndexEntry"> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the point of having entry as an attribute? Do we ever need a single entry? Otherwise, we are just increasing the cost of manipulating these objects without benefit: attributes are created under lock and accessing each attribute adds a pointer indirection to get to the context-owned memory. Unless we use or intend to use index entries separately, we just store an array of pairs, pretty much like DictAttr does.

/// Look up the index mapping for a given dimension symbol.
/// Returns std::nullopt if the dimension is not found.
/// Complexity: O(n) where n is the number of dimensions (typically 2-4).
std::optional<::wave::WaveIndexMappingAttr>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You never need optional around a type that is nullable. Just use null as failure. Otherwise you have two "failure" states with no clear difference.

Comment on lines +577 to +579
/// Look up by dimension name string.
std::optional<::wave::WaveIndexMappingAttr>
lookup(::llvm::StringRef dimensionName) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to avoid string-based APIs as much as possible. If we don't need it, let's remove this.

lookup(::llvm::StringRef dimensionName) const;

/// Get the ordered list of dimension symbols.
::llvm::SmallVector<::wave::WaveSymbolAttr> getDimensions() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think you can make the return type auto if you put implementation here, and just return llvm::map_range(...) (the actual type is atrocious), which will avoid the need to construct and potentially copy a vector.

/// Ordering semantics:
/// - LHS entries come first (in LHS order), then RHS-only entries (in RHS
/// order).
/// - Entries with the same dimension have their mappings merged.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged how exactly? This should refer to the actual implementation function that has this documentation.

/// order. This is the common case for elementwise ops like wave.add.
///
/// 2. **MMA ops**: LHS has {M, K}, RHS has {N, K}, accumulator/result has
/// {M, N}. The `ignoredRhsSymbols` parameter filters dimensions that
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem to say anything about the order, which is the kinda the point of this documentation...

Comment on lines +431 to +433
/// 3. **Iterate ops**: Block arguments are joined with iter_args, and
/// terminator operands with results. Both should have matching tensor
/// types and thus matching dimension order.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This describes a quite fundamental verification constraint that any control-flow operation follows, why do we need it here exactly?

);

string commonArgumentsSyntax = "( `index` custom<WaveIndexDict>($index)^ )?";
// Index is now printed as part of attr-dict using standard WaveIndexExprsAttr format.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then just remove commonArgumentsSyntax entirely.

Also, as I mentioned in some previous review, comments should explain the code as it is right now, not compare to how it used to be but is not anymore (the "now" part). This is a thoroughly unnecessary complexity forcing whoever reads the comment to remember or discover how things used to work, which makes the value of comments negative.

if (parser.parseCustomAttributeWithFallback<WaveSymbolAttr>(dimension))
return {};

// Parse colon
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could guess this from the line immediately below. Please keep obvious comments to the minimum.

SmallVector<DictionaryAttr> dicts;
dicts.reserve(arr.size());
return op->emitError(
"'index' attribute must be an array of WaveIndexExprsAttr");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be one with a constraint in ODS

Copy link
Contributor

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite follow the idea of LHS/RHS join of symbols. Now that index expressions are supposed to reflect the order of dimensions in some tensor, shouldn't we somehow enforce that order? It may be fine in the join itself, but we will need to reorder the entries before setting the final value after the analysis.

Relatedly, I suspect we want a verifier that checks whether the order of dimensions in the index mapping attribute matches that of some tensor. For elementwise arithmetic, this is straightforward, but one needs to think how to implement this for other operations, e.g., MMAs where none of the tensors have all dimensions? Reductions? Reshape?

Separately, how do we handle the situation when two mappings are identical except for the ordering?

Comment on lines +857 to +861
llvm::SmallVector<wave::WaveIndexEntryAttr> entries;
entries.reserve(result.size());
for (auto &[dimension, mapping] : result) {
entries.push_back(wave::WaveIndexEntryAttr::get(ctx, dimension, mapping));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a takeVector on MapVector? The name of the data structure kinda implies there's a vector inside.

Comment on lines +736 to +737
// Populate the entries with all index expressions.
void populate(llvm::SmallVectorImpl<wave::WaveIndexEntryAttr> &entries) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This renaming is a kind of a mindless churn that makes everyone's live harder for no good reason (attributes is a perfectly valid description of WaveIndexEntryAttr): rebases are atrocious, this PR appears on git blames when it is not very relevant, integration suffers. It may be easy to implement with a tool, but causes negative externalities ad should be avoided.

Now, ironically, I'm pushing to remove the EntryAttr, at which point this change will make more sense, so let's keep it, but it was important to point this out.

Comment on lines +535 to +536
Syntax: @M : [symbols] -> (start, step, stride)
Example: @M : [#wave.index_symbol<WG0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M, 1, BLOCK_M)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THis PR literally has tests that show different syntax. Don't put examples in here, they are bitrotten before you even sent your PR!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants