[Perf] Update causal conv1d fn for better perf #4759
base: main
Conversation
Signed-off-by: SunnyLee219 <3294305115@qq.com>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request aims to improve the performance of causal_conv1d_fn by removing a loop and a .tolist() operation. While the performance improvement is valid for a batch size of one, the changes introduce a critical bug that breaks support for batch processing. The updated implementation only handles the first sequence in a batch and ignores the rest, which will lead to incorrect outputs for batches larger than one. I've provided a suggestion to fix the batching logic while retaining a bugfix for tensor shapes that was also part of this change.
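For context, here is a minimal sketch of the packed-batch layout that causal_conv1d_fn is expected to handle; the shapes and variable names are illustrative assumptions, not the project's actual test fixtures:

import torch

# Illustrative only: two sequences of different lengths packed along the
# last dimension and addressed via cumulative offsets in query_start_loc.
dim, seq_a, seq_b = 4, 3, 5
x = torch.randn(dim, seq_a + seq_b)              # packed input: (dim, total_tokens)
query_start_loc = torch.tensor([0, seq_a, seq_a + seq_b])

# Per-sequence lengths come from consecutive offsets...
seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()  # [3, 5]
# ...and each sequence must be convolved separately, which is why the
# per-sequence loop matters for batch sizes greater than one.
for x_s in torch.split(x, seqlens, dim=-1):
    print(x_s.shape)  # torch.Size([4, 3]) then torch.Size([4, 5])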
out_ref_b.append(
    causal_conv1d_fn_native(
        x,
        weight,
        bias,
        activation=activation,
        return_final_states=True,
        final_states_out=conv_states[cache_indices[0]].unsqueeze(0),
        initial_states=(
            conv_states[cache_indices[0]].unsqueeze(0)
            if has_initial_state[0]
            else None
        ),
    )
)
The refactoring of this function to improve performance has unfortunately broken the batch processing capability. The original implementation iterated over sequences in a batch, which is necessary for handling variable-length sequences packed together. The new implementation removes this loop and only processes the first element of the batch (as evidenced by the use of cache_indices[0] and has_initial_state[0]). This will result in incorrect behavior for any input with a batch size greater than one.
While this change fixes a shape mismatch for initial_states by adding .unsqueeze(0), removing the loop is incorrect. The correct approach is to restore the loop for batch processing and apply the unsqueeze(0) fix within the loop. Here is a suggested implementation:
seqlens = query_start_loc[1:] - query_start_loc[:-1]
seqlens = seqlens.tolist()
splits = torch.split(x, seqlens, dim=-1)
for i in range(len(seqlens)):
    x_s = splits[i]
    if cache_indices[i] == PAD_SLOT_ID:
        continue
    out_ref_b.append(
        causal_conv1d_fn_native(
            x_s,
            weight,
            bias,
            activation=activation,
            return_final_states=True,
            final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
            initial_states=(
                conv_states[cache_indices[i]].unsqueeze(0)
                if has_initial_state[i]
                else None
            ),
        )
    )

Signed-off-by: SunnyLee219 <3294305115@qq.com>
What this PR does / why we need it?
Update causal_conv1d_fn for better performance. Specifically, derive seqlens from the input x.shape rather than from query_start_loc, which removes the .tolist() operation and reduces execution time.
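For illustration, a minimal sketch of the change described above; the function names are assumptions based on the PR description, not the exact diff. Deriving the length from x.shape avoids the device-to-host synchronization that calling .tolist() on query_start_loc incurs:

import torch

def seqlens_from_offsets(query_start_loc: torch.Tensor) -> list[int]:
    # Original approach: diff of cumulative offsets, then .tolist(),
    # which forces a device-to-host synchronization.
    return (query_start_loc[1:] - query_start_loc[:-1]).tolist()

def seqlen_from_shape(x: torch.Tensor) -> int:
    # PR approach (sketch): read the length directly from the packed
    # input's shape, with no host sync. Note this is only equivalent
    # when the batch holds a single sequence, which is the reviewer's
    # concern in the comment above.
    return x.shape[-1]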
Does this PR introduce any user-facing change?
No
How was this patch tested?