Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@
from gt4py.next.type_system import type_info, type_specifications as ts


def _is_literal_expr(node: itir.Node) -> bool:
"""Return if node is a `Literal` or a tuple thereof."""
if isinstance(node, itir.Literal):
return True
if cpm.is_call_to(node, "make_tuple") and all(_is_literal_expr(arg) for arg in node.args):
return True
if cpm.is_call_to(node, "tuple_get") and _is_literal_expr(node.args[1]):
return True
return False


def _is_trivial_tuple_expr(node: itir.Expr):
"""Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof."""
if cpm.is_call_to(node, "make_tuple") and all(
Expand Down Expand Up @@ -78,6 +89,9 @@ def visit_FunCall(self, node: itir.FunCall) -> itir.Node:


def _is_collectable_expr(node: itir.Node) -> bool:
if _is_literal_expr(node):
# do not collect literal expressions
return False
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import gt4py.next as gtx
import numpy as np

from next_tests.integration_tests.cases import KDim, cartesian_case
from next_tests.integration_tests import cases
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
exec_alloc_descriptor,
)


@pytest.mark.uses_scan
def test_scan_init_duplicated(cartesian_case):
"""
Tests that a non-trivial duplicated expression in the `init` argument of a scan operator works.

GTFN currently doesn't like if the expression gets cse-extracted.
"""

@gtx.scan_operator(axis=KDim, forward=True, init=((1.0,), (1.0,)))
def testee_scan(
state: tuple[tuple[float], tuple[float]], inp: float
) -> tuple[tuple[float], tuple[float]]:
return (state[0][0] + inp,), (state[1][0] + inp,)

@gtx.field_operator
def testee(
inp: gtx.Field[[KDim], float],
) -> tuple[tuple[gtx.Field[[KDim], float]], tuple[gtx.Field[[KDim], float]]]:
return testee_scan(inp)

inp = cases.allocate(cartesian_case, testee, "inp")()
out = cases.allocate(cartesian_case, testee, cases.RETURN).zeros()()

cases.verify(
cartesian_case,
testee,
inp,
out=out,
ref=(
(np.cumsum(inp.asnumpy(), axis=0) + 1.0,),
(np.cumsum(inp.asnumpy(), axis=0) + 1.0,),
),
)