From 853c43ae722d4c5ca4788f83c2df5d243e3b591d Mon Sep 17 00:00:00 2001 From: Chase Naples Date: Sun, 21 Dec 2025 18:49:04 -0500 Subject: [PATCH] Add op test for torch.unique_consecutive Added OpInfo-based test for torch.unique_consecutive operator: - Created sample_inputs_unique_consecutive function in extra_opinfo.py that reuses common_methods_invocations.sample_inputs_unique - Added OpInfo entry for ops.aten.unique_consecutive with integral_types - Added TorchLibOpInfo entry in ops_test_data.py Fixes #2695 --- tests/function_libs/torch_lib/extra_opinfo.py | 18 ++++++++++++++++++ tests/function_libs/torch_lib/ops_test_data.py | 1 + 2 files changed, 19 insertions(+) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 2ce015b363..e1ebf799e6 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -2270,6 +2270,16 @@ def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs): yield sample +def sample_inputs_unique_consecutive(op_info, device, dtype, requires_grad, **kwargs): + for sample in common_methods_invocations.sample_inputs_unique( + op_info, device, dtype, requires_grad, **kwargs + ): + # unique_consecutive only supports dim=None or (dim=0 with rank=1) + # So filter out samples with dim != None + if sample.kwargs.get("dim") is None: + yield sample + + def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -2878,6 +2888,14 @@ def __init__(self): supports_out=False, supports_autograd=False, ), + opinfo_core.OpInfo( + "ops.aten.unique_consecutive", + aten_name="unique_consecutive", + dtypes=common_dtype.integral_types(), + sample_inputs_func=sample_inputs_unique_consecutive, + supports_out=False, + supports_autograd=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_bicubic2d.default", aten_name="upsample_bicubic2d", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index f6ce0f5176..46b4958f83 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1865,6 +1865,7 @@ def _where_input_wrangler( "Our implementation is based on that for CUDA" ), ), + TorchLibOpInfo("ops.aten.unique_consecutive", core_ops.aten_unique_consecutive), TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim), TorchLibOpInfo( "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}