From 7e6182b12086a352c714bb033f42ad87997332e2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 16:36:26 -0700 Subject: [PATCH] Test fft normalization --- onnxscript/function_libs/torch_lib/ops/fft.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index ea92dc347d..037f69feb4 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -14,7 +14,7 @@ from typing import Optional, Sequence -from onnxscript import INT64 +from onnxscript import INT64, ir from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import TFloat from onnxscript.onnx_opset import opset18 as op @@ -118,12 +118,18 @@ def aten__fft_c2r( # Torch truncates/pads on the last dimension only. Typically, the only valid values that can be passed # into PyTorch are n or n//2+1, where n is self.shape[dim[-1]], but this is not always the case, so we # place no such restriction on the ONNX side. - transformed = op.DFT( - transformed, - dft_length=last_dim_size, - axis=dimension, - inverse=True, - onesided=False, + scale = (op.CastLike(last_dim_size, self)) / op.CastLike( + op.Shape(transformed, start=dimension, end=dimension + 1), self + ) + transformed = ( + op.DFT( + transformed, + dft_length=last_dim_size, + axis=dimension, + inverse=True, + onesided=False, + ) + * scale ) transformed = _fftn_onnx_normalization( transformed,