From 98456c9ba29e24ba45543ae8915d042f4182edfb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 10:45:11 -0700 Subject: [PATCH 1/6] [torchlib] Make index_put dynamic --- .../function_libs/torch_lib/ops/core.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ea43c2c4db..56498c2dec 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4323,11 +4323,12 @@ def aten_index_put( `_. """ - def _make_reshape_list_broadcastable(reshape_list, values_shape): + def _make_reshape_list_broadcastable(reshape_list: list[INT64], values_shape, values_rank: int): # Remove ones until the rank of reshape_list matches values_shape. - while len(reshape_list) > len(values_shape) and 1 in reshape_list: + while len(reshape_list) > values_rank and 1 in reshape_list: reshape_list.remove(1) + TODO(justinchuby): Here # Now ensure each dimension is broadcastable: # This is mandatory when mixing basic and advanced indexing # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) @@ -4348,33 +4349,35 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): indices = list(indices) + [None] * (self_rank - len(indices)) # Get values shape - values_shape = tuple(values.shape) + values_shape = op.Shape(values) + values_rank = len(values.shape) # Statically known index_vectors = [] for i in range(self_rank): if indices[i] is None: # For a full slice along dim i, create a range index [0, self.shape[i]). - idx = op.Range(0, self.shape[i], 1) - reshape_update = self.shape[i] + idx = op.Range(0, op.Shape(self), 1) + reshape_update = op.Shape(self, start=i, end=i+1) else: idx = indices[i] - reshape_update = math.prod(idx.shape) + reshape_update = op.ReduceProd(op.Shape(idx), keepdims=True) # when Index is more than 1D, flatten it and also the values shape # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) # Indices -> (2*4,) and values shape (2*4, 32) - if len(idx.shape) > 1: - values_shape = (reshape_update, *values_shape[len(idx.shape) :]) + idx_rank = len(idx.shape) + if idx_rank > 1: + values_shape = op.Concat(reshape_update, op.Shape(values, start=idx_rank, end=values_rank)) # Flatten index (always working with 1D index in each dim) idx = op.Reshape(idx, [-1]) # Create a reshape pattern: one value per index dimension, # with the current dimension set to the update size. - reshape_list = [1] * len(indices) + reshape_list = [[1]] * len(indices) reshape_list[i] = reshape_update # Adjust the reshape list to match the values shape. - reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) + reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape, values_rank) # Reshape and expand the index. 
idx = op.Reshape(idx, reshape_list) From 20a0df52a33bfcefb8b2ffb3cb0630162bb24978 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 29 Apr 2025 13:08:33 -0700 Subject: [PATCH 2/6] wip --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 56498c2dec..a83b07fe2a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4334,7 +4334,7 @@ def _make_reshape_list_broadcastable(reshape_list: list[INT64], values_shape, va # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) # the reshape list should be : [[2, 1], [1, 3], [2, 1]] for i, r in enumerate(reshape_list): - if r not in (1, values_shape[i]): + if isinstance(r, int) and r not in (1, values_shape[i]): value_index = values_shape.index(r) # Swap elements # For the example above the current reshape list is [1, 2] for last dim, From 6ca6d9e79ddfdee9167c2c8f4aa69215c5e65d88 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 26 Jun 2025 09:38:23 -0700 Subject: [PATCH 3/6] wip Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 99 +++++-------------- 1 file changed, 26 insertions(+), 73 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 50ff82724f..dab65f4f5e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4324,91 +4324,44 @@ def aten_index_copy( @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, - indices: Sequence[INT64], + indices: Sequence[Optional[INT64]], values: TReal, accumulate: bool = False, ) -> TReal: - """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor - - See implementation of `torch.onnx.symbolic_opset11.index_put - `_. - """ - - def _make_reshape_list_broadcastable(reshape_list: list[INT64], values_shape, values_rank: int): - # Remove ones until the rank of reshape_list matches values_shape. - while len(reshape_list) > values_rank and 1 in reshape_list: - reshape_list.remove(1) - - TODO(justinchuby): Here - # Now ensure each dimension is broadcastable: - # This is mandatory when mixing basic and advanced indexing - # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) - # the reshape list should be : [[2, 1], [1, 3], [2, 1]] - for i, r in enumerate(reshape_list): - if isinstance(r, int) and r not in (1, values_shape[i]): - value_index = values_shape.index(r) - # Swap elements - # For the example above the current reshape list is [1, 2] for last dim, - # to make it broadcastable, we swap the elements - reshape_list[value_index], reshape_list[i] = r, 1 - - return reshape_list + """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" # Ensure the number of indices matches the tensor rank. self_rank = len(self.shape) - if len(indices) < self_rank: - indices = list(indices) + [None] * (self_rank - len(indices)) - - # Get values shape - values_shape = op.Shape(values) - values_rank = len(values.shape) # Statically known - - index_vectors = [] - for i in range(self_rank): - if indices[i] is None: - # For a full slice along dim i, create a range index [0, self.shape[i]). 
- idx = op.Range(0, op.Shape(self), 1) - reshape_update = op.Shape(self, start=i, end=i+1) - else: - idx = indices[i] - reshape_update = op.ReduceProd(op.Shape(idx), keepdims=True) - # when Index is more than 1D, flatten it and also the values shape - # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) - # Indices -> (2*4,) and values shape (2*4, 32) - idx_rank = len(idx.shape) - if idx_rank > 1: - values_shape = op.Concat(reshape_update, op.Shape(values, start=idx_rank, end=values_rank)) - - # Flatten index (always working with 1D index in each dim) - idx = op.Reshape(idx, [-1]) - - # Create a reshape pattern: one value per index dimension, - # with the current dimension set to the update size. - reshape_list = [[1]] * len(indices) - reshape_list[i] = reshape_update - - # Adjust the reshape list to match the values shape. - reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape, values_rank) - - # Reshape and expand the index. - idx = op.Reshape(idx, reshape_list, allowzero=True) - idx = op.Expand(idx, values_shape) + index_ranks = [len(index.shape) for index in indices if index is not None] + advanced_indexing_rank = max(index_ranks) - # Flatten the index to 1D and unsqueeze to form a column vector. - idx = op.Reshape(idx, [-1]) - idx = op.Unsqueeze(idx, axes=[1]) - index_vectors.append(idx) + # reordered_positions is the permutation of the index positions where + # positions with None are move to the end of the list + # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2] + reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i)) - # Concatenate the index vectors along axis=1 to form the final indices. - new_index = op.Concat(*index_vectors, axis=1) + # Fill the list with the remaining indices up to the rank of the tensor self. 
+ # For example, if indices = [None, 1, None, 2], and the rank of self is 6, + # then reordered_positions = [1, 3, 0, 2, 4, 5] + reordered_positions = [ + *reordered_positions, + *range(len(reordered_positions), self_rank), + ] + # Transpose self according to the reordered positions + self = op.Transpose(self, perm=reordered_positions) - # Flatten values to match the indices - flat_values = op.Reshape(values, [-1]) + # Broadcast the indices to the same shape then concatenate + not_none_indices = [idx for idx in indices if idx is not None] + broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices) + final_index = op.Concat( + *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices), + axis=-1, + ) if accumulate: - result = op.ScatterND(self, new_index, flat_values, reduction="add") + self = op.ScatterND(self, final_index, values, reduction="add") else: - result = op.ScatterND(self, new_index, flat_values) + self = op.ScatterND(self, final_index, values) return result From f1704b9a207aee9137d089f00a0654e03ccf63c0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 8 Jul 2025 15:46:03 -0700 Subject: [PATCH 4/6] try this Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dab65f4f5e..0e30673b10 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4340,6 +4340,8 @@ def aten_index_put( # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2] reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i)) + values = op.Transpose(values, perm=reordered_positions) + # Fill the list with the remaining indices up to the rank of the tensor self. # For example, if indices = [None, 1, None, 2], and the rank of self is 6, # then reordered_positions = [1, 3, 0, 2, 4, 5] @@ -4363,8 +4365,52 @@ def aten_index_put( else: self = op.ScatterND(self, final_index, values) - return result + if _has_none_in_middle(indices): + # If there is None in the middle, Advanced Indexing cannot decide where to put + # the new dimensions. So it places them in the front, like GatherND does. + return op.Identity(self) + # When the indices are consecutive, Advanced Indexing will place the new dimensions + # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes. + # + # Input index axes (three parts): + # [ + # x_None_front_1, ... x_None_front_m, + # x1, ..., xk, + # x_None_back_1, ..., x_None_back_m + # ] + # GatherND result axes: + # [ + # *broadcasted_shape(x1, x2, ..., xk), + # x_None_front_1, ... x_None_front_m, + # x_None_back_1, ..., x_None_back_m + # ] + # (Transpose here) + # Advanced indexing result axes: + # [ + # x_None_front_1, ... x_None_front_m, + # *brocasted_shape(x1, x2, ..., xk), + # x_None_back_1, ..., x_None_back_m + # ] + # + # Need to transpose the result of GatherND to match this axes ordering. 
+ first_not_none_position = reordered_positions[0] # x_None_front_m + 1 + starting_position_of_none_in_back = ( + advanced_indexing_rank + first_not_none_position + ) # x_None_back_1 + result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank + perm = [ + *range( + advanced_indexing_rank, starting_position_of_none_in_back + ), # None_front_1...x_None_back_1 + *range(advanced_indexing_rank), # 0...len(broadcasted_shape) + *range( + starting_position_of_none_in_back, + result_rank, + ), # None_back_1...None_back_m + ] + + return op.Transpose(self, perm=perm) @torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( From d9e55dd1c540ab318a8230510254e561120368ea Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 9 Jul 2025 17:08:40 -0700 Subject: [PATCH 5/6] Update Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 116 +++++------------- 1 file changed, 32 insertions(+), 84 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 0e30673b10..3b1dbed040 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4218,23 +4218,9 @@ def _aten_index_onnx( # ] # # Need to transpose the result of GatherND to match this axes ordering. - first_not_none_position = reordered_positions[0] # x_None_front_m + 1 - starting_position_of_none_in_back = ( - advanced_indexing_rank + first_not_none_position - ) # x_None_back_1 - result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank - perm = [ - *range( - advanced_indexing_rank, starting_position_of_none_in_back - ), # None_front_1...x_None_back_1 - *range(advanced_indexing_rank), # 0...len(broadcasted_shape) - *range( - starting_position_of_none_in_back, - result_rank, - ), # None_back_1...None_back_m - ] + inverse_positions = np.argsort(reordered_positions).tolist() - return op.Transpose(self, perm=perm) + return op.Transpose(self, perm=inverse_positions) @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) @@ -4332,85 +4318,47 @@ def aten_index_put( # Ensure the number of indices matches the tensor rank. self_rank = len(self.shape) - index_ranks = [len(index.shape) for index in indices if index is not None] - advanced_indexing_rank = max(index_ranks) - # reordered_positions is the permutation of the index positions where - # positions with None are move to the end of the list - # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2] + # 1. Reorder input tensor so that None-indexed axes are last + # This logic is identical to the aten.index implementation. reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i)) + remaining_dims = [i for i in range(self_rank) if i not in reordered_positions] + reordered_positions.extend(remaining_dims) - values = op.Transpose(values, perm=reordered_positions) + # Transpose the input data to group the indexed dimensions first + transposed_self = op.Transpose(self, perm=reordered_positions) - # Fill the list with the remaining indices up to the rank of the tensor self. 
- # For example, if indices = [None, 1, None, 2], and the rank of self is 6, - # then reordered_positions = [1, 3, 0, 2, 4, 5] - reordered_positions = [ - *reordered_positions, - *range(len(reordered_positions), self_rank), - ] - # Transpose self according to the reordered positions - self = op.Transpose(self, perm=reordered_positions) - - # Broadcast the indices to the same shape then concatenate + # 2. Prepare indices for ScatterND + # This logic is also identical. not_none_indices = [idx for idx in indices if idx is not None] broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices) - final_index = op.Concat( - *(op.Unsqueeze(op.Expand(idx, broadcast_shape), [-1]) for idx in not_none_indices), - axis=-1, - ) + final_index_parts = [] + for idx in not_none_indices: + # Unsqueeze is needed to make indices broadcastable to the common shape + expanded_idx = op.Expand(idx, broadcast_shape) + final_index_parts.append(op.Unsqueeze(expanded_idx, [-1])) + + final_index = op.Concat(*final_index_parts, axis=-1) + + # 3. Prepare the 'updates' tensor (values) + # The 'values' tensor must be broadcast to match the shape of the + # broadcasted indices. + expanded_values = op.Expand(values, broadcast_shape) + + # 4. Perform the scatter operation if accumulate: - self = op.ScatterND(self, final_index, values, reduction="add") + scattered_data = op.ScatterND(transposed_self, final_index, expanded_values, reduction="add") else: - self = op.ScatterND(self, final_index, values) - - if _has_none_in_middle(indices): - # If there is None in the middle, Advanced Indexing cannot decide where to put - # the new dimensions. So it places them in the front, like GatherND does. - return op.Identity(self) + scattered_data = op.ScatterND(transposed_self, final_index, expanded_values) - # When the indices are consecutive, Advanced Indexing will place the new dimensions - # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes. - # - # Input index axes (three parts): - # [ - # x_None_front_1, ... x_None_front_m, - # x1, ..., xk, - # x_None_back_1, ..., x_None_back_m - # ] - # GatherND result axes: - # [ - # *broadcasted_shape(x1, x2, ..., xk), - # x_None_front_1, ... x_None_front_m, - # x_None_back_1, ..., x_None_back_m - # ] - # (Transpose here) - # Advanced indexing result axes: - # [ - # x_None_front_1, ... x_None_front_m, - # *brocasted_shape(x1, x2, ..., xk), - # x_None_back_1, ..., x_None_back_m - # ] - # - # Need to transpose the result of GatherND to match this axes ordering. - first_not_none_position = reordered_positions[0] # x_None_front_m + 1 - starting_position_of_none_in_back = ( - advanced_indexing_rank + first_not_none_position - ) # x_None_back_1 - result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank - perm = [ - *range( - advanced_indexing_rank, starting_position_of_none_in_back - ), # None_front_1...x_None_back_1 - *range(advanced_indexing_rank), # 0...len(broadcasted_shape) - *range( - starting_position_of_none_in_back, - result_rank, - ), # None_back_1...None_back_m - ] + # 5. Restore original dimension order + # The output of ScatterND has the same shape as the transposed input. + # We must apply an "inverse" transpose to get the final result. 
+ inverse_positions = np.argsort(reordered_positions).tolist() + final_output = op.Transpose(scattered_data, perm=inverse_positions) - return op.Transpose(self, perm=perm) + return final_output @torch_op("aten::index_put", trace_only=True) def aten_index_put_bool( From 8da2e77d0e9dfce96ea1c891a720a2468dc35d9e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 10 Jul 2025 11:03:00 -0700 Subject: [PATCH 6/6] wip Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 3b1dbed040..7e86869d93 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4345,6 +4345,8 @@ def aten_index_put( # The 'values' tensor must be broadcast to match the shape of the # broadcasted indices. expanded_values = op.Expand(values, broadcast_shape) + # TODO: Handle None + expanded_values = op.Transpose(expanded_values, perm=reordered_positions) # 4. Perform the scatter operation if accumulate:
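
Note on the final decomposition (illustrative only, not part of the patches): the approach the
later patches converge on is (1) Transpose `self` so that None-indexed axes come last,
(2) Expand the advanced index tensors to their common broadcast shape and Concat them along the
last axis to form ScatterND coordinates, (3) ScatterND the broadcast `values`, optionally with
reduction="add", and (4) undo the initial Transpose using the inverse permutation obtained from
np.argsort. The NumPy sketch below mirrors only the simplest case, where every entry of
`indices` is an integer tensor (no None slots, so the reordering transpose is the identity);
the helper name `index_put_sketch` is invented for this note, and NumPy stands in for the ONNX
ops rather than reproducing the exported graph.

    import numpy as np

    def index_put_sketch(data, indices, values, accumulate=False):
        # ScatterND-style emulation of aten::index_put when every entry of
        # `indices` is an integer array (no None slots). `data` is copied,
        # not modified in place.
        data = data.copy()
        # Broadcast all index arrays to one common shape; the export does the
        # same with Expand over the broadcast shape followed by Concat(axis=-1).
        bcast = np.broadcast_arrays(*indices)
        # Updates must broadcast to broadcast_shape + data.shape[len(indices):].
        update_shape = bcast[0].shape + data.shape[len(indices):]
        updates = np.broadcast_to(values, update_shape)
        if accumulate:
            # Matches ScatterND(reduction="add"): duplicate coordinates accumulate.
            np.add.at(data, tuple(bcast), updates)
        else:
            # Matches plain ScatterND: duplicate coordinates overwrite.
            data[tuple(bcast)] = updates
        return data

    # Equivalent to x[[0, 2], [1, 1]] = [10., 20.]
    x = np.zeros((5, 3))
    out = index_put_sketch(x, [np.array([0, 2]), np.array([1, 1])], np.array([10.0, 20.0]))

    # For a permutation p, np.argsort(p) is its inverse, which is what the last
    # two patches rely on to undo the initial Transpose of `self`.
    p = [1, 3, 0, 2, 4, 5]
    assert np.transpose(np.transpose(np.zeros((2, 3, 4, 5, 6, 7)), p), np.argsort(p)).shape == (2, 3, 4, 5, 6, 7)

The argsort trick is what removes the hand-built `perm` computation from PATCH 4: transposing by
a permutation and then by its argsort restores the original axis order, regardless of where the
None-indexed axes sat.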