
Conversation

Contributor

SunnyLee151064 commented Dec 6, 2025

What this PR does / why we need it?

Update causal_conv1d_fn for better performance. Specifically, derive the sequence lengths from the input x.shape rather than from query_start_loc, which removes the .tolist() operation and its device-to-host synchronization, reducing kernel launch latency.
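
A minimal sketch of the idea (hypothetical variable names; the actual diff may differ in detail):

    # Before: per-sequence lengths are computed from query_start_loc and
    # moved to the host via .tolist(), which forces a device-to-host sync
    seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()

    # After: the sequence length can be read directly from the input
    # tensor's shape, so no .tolist() and no device sync is needed
    seqlen = x.shape[-1]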

Does this PR introduce any user-facing change?

No

How was this patch tested?

Signed-off-by: SunnyLee219 <3294305115@qq.com>

github-actions bot commented Dec 6, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

gemini-code-assist bot left a comment


Code Review

This pull request aims to improve the performance of causal_conv1d_fn by removing a loop and a .tolist() operation. While the performance improvement is valid for a batch size of one, the changes introduce a critical bug that breaks support for batch processing. The updated implementation only handles the first sequence in a batch and ignores the rest, which will lead to incorrect outputs for batches larger than one. I've provided a suggestion to fix the batching logic while retaining a bugfix for tensor shapes that was also part of this change.

Comment on lines 110 to 124
out_ref_b.append(
    causal_conv1d_fn_native(
        x,
        weight,
        bias,
        activation=activation,
        return_final_states=True,
        final_states_out=conv_states[cache_indices[0]].unsqueeze(0),
        initial_states=(
            conv_states[cache_indices[0]].unsqueeze(0)
            if has_initial_state[0]
            else None
        ),
    )
)

critical

The refactoring of this function to improve performance has unfortunately broken the batch processing capability. The original implementation iterated over sequences in a batch, which is necessary for handling variable-length sequences packed together. The new implementation removes this loop and only processes the first element of the batch (as evidenced by the use of cache_indices[0] and has_initial_state[0]). This will result in incorrect behavior for any input with a batch size greater than one.

While this change fixes a shape mismatch for initial_states by adding .unsqueeze(0), removing the loop is incorrect. The correct approach is to restore the loop for batch processing and apply the unsqueeze(0) fix within the loop. Here is a suggested implementation:

    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)

    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_fn_native(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
                initial_states=(
                    conv_states[cache_indices[i]].unsqueeze(0)
                    if has_initial_state[i]
                    else None
                ),
            )
        )
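
For reference, the unsqueeze(0) presumably compensates for indexing conv_states with a single cache index, which drops the batch dimension that causal_conv1d_fn_native expects on initial_states and final_states_out; the suggestion keeps that shape fix while restoring the per-sequence loop.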

Signed-off-by: SunnyLee219 <3294305115@qq.com>