Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
subsets as dace_subsets,
transformation as dace_transformation,
)
from dace.cli import progress as dace_cliprogress
from dace.sdfg import nodes as dace_nodes, utils as dace_sdutils
from dace.transformation import (
dataflow as dace_dataflow,
Expand Down Expand Up @@ -104,7 +103,6 @@ def gt_simplify(
if "InlineSDFGs" not in skip:
inline_res = gt_inline_nested_sdfg(
sdfg=sdfg,
multistate=True,
permissive=False,
validate=False,
validate_all=validate_all,
Expand Down Expand Up @@ -257,11 +255,9 @@ def gt_simplify(

def gt_inline_nested_sdfg(
sdfg: dace.SDFG,
multistate: bool = True,
permissive: bool = False,
validate: bool = True,
validate_all: bool = False,
progress: Optional[bool] = None,
) -> Optional[dict[str, int]]:
"""Perform inlining of nested SDFG into their parent SDFG.

Expand All @@ -272,24 +268,64 @@ def gt_inline_nested_sdfg(

Args:
sdfg: The SDFG that should be processed, will be modified in place and returned.
multistate: Allow inlining of multistate nested SDFG, defaults to `True`.
permissive: Be less strict on the accepted SDFGs.
validate: Perform validation after the transformation has finished.
validate_all: Performs extensive validation.

Note:
- This function grantees a stable processing order, if the name of the nested
SDFGs and the name of the state they are located in, is stable.
- The `no_inline` attribute of the `NestedSDFG` flag only affects the inlining
of that specific node. The clearing transformations and the recursive
processing, i.e. inlining of NestedSDFGs inside the nested SDFG is still
performed.
"""

# NOTE: DaCe has three(!) inliner. First `InlineMultistateSDFG`, that we employ,
# secondly `InlineSDFG`, which is only capable of inlining an SDFG with a single
# state and `InlineSDFGs` which combines the two. However, `InlineSDFG` has a
# bug and the processing order of `InlineSDFGs` is not stable. Thus GT4Py
# implements its own version.

# Finding all nested SDFGs on this level.
nested_sdfgs_to_process: list[dace_nodes.NestedSDFG] = []
for state in sdfg.states():
nested_sdfgs_to_process.extend(
node for node in state.nodes() if isinstance(node, dace_nodes.NestedSDFG)
)

# If there are no SDFGs to process then we exit.
if len(nested_sdfgs_to_process) == 0:
return None

# Now order them, such that we can process them in a stable way.
nested_sdfgs_to_process.sort(key=lambda nsdfg: (str(nsdfg.label), str(nsdfg.sdfg.parent.label)))

nb_preproccess_total = 0
nb_inlines_total = 0
nsdfgs = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace_nodes.NestedSDFG)]
for nsdfg_node in dace_cliprogress.optional_progressbar(
reversed(nsdfgs), title="Inlining SDFGs", n=len(nsdfgs), progress=progress
):
nsdfg: dace.SDFG = nsdfg_node.sdfg
parent_state = nsdfg.parent

# Now we start inlining all the nested SDFGs.
# Before a nested SDFG is inlined the function first tires to inline all
# SDFGs that are nested inside it, i.e. they are processed in a stable
# DFS order.
for nsdfg_node in nested_sdfgs_to_process:
nested_sdfg: dace.SDFG = nsdfg_node.sdfg
parent_state = nested_sdfg.parent
parent_sdfg = parent_state.sdfg
parent_state_id = parent_state.block_id

# Clean the symbols and connectors of the nested SDFG.
# Recursive processing of nested SDFGs.
recursive_result = gt_inline_nested_sdfg(
sdfg=nsdfg_node.sdfg,
permissive=permissive,
validate=False,
validate_all=validate_all,
)
if recursive_result is not None:
nb_preproccess_total += recursive_result.get("PruneSymbols|PruneConnectors", 0)
nb_inlines_total += recursive_result.get("InlineSDFGs", 0)

# Now perform some cleaning on the nested SDFG.
for xform in [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors]:
candidate = {xform.nsdfg: nsdfg_node}
cleaner = xform()
Expand All @@ -305,14 +341,24 @@ def gt_inline_nested_sdfg(
cleaner.apply(parent_state, parent_sdfg)
nb_preproccess_total += 1

# Inlining an SDFG is only possible if the nested SDFG node is at global scope.
if parent_state.scope_dict()[nsdfg_node] is not None:
continue

# Check the `no_inline` flag. Note that it has to be checked here and not
# before to ensure that the node is recursively processed and the pruning
# transformations are applied.
if nsdfg_node.no_inline:
continue

# Now perform the actual inlining.
# NOTE: In [PR#2178](https://github.com/GridTools/gt4py/pull/2178) this function was
# modified to be more efficient. It also changed the order in which the inlining
# transformations of DaCe were applied. Instead of trying `InlineMultistateSDFG`
# it changed that such that `InlineSDFG` was used. However, this triggered
# [issue#2108](https://github.com/spcl/dace/issues/2108) which lead to the removals
# of some writes. As a temporary solution we no longer use `InlineSDFG` but only
# the multistate version.
# TODO(phimuell): As soon as the DaCe issue is resolved start using `InlineSDFG` again.
multi_state_candidate = {dace_interstate.InlineMultistateSDFG.nested_sdfg: nsdfg_node}
multi_state_inliner = dace_interstate.InlineMultistateSDFG()
multi_state_inliner.setup_match(
Expand Down