
YanivDorGalron commented on Dec 5, 2025

What does this PR do?

Many discussions have focused on speculative decoding with batch_size > 1 (#26875, #29769, #32165, #32189), but none of them led to a full implementation. This PR implements it.

Summary of Changes: Enable Batched Speculative Decoding

Until now, invoking generate() with an assistant_model raised a ValueError if batch_size > 1. This update modifies the codebase to support batched speculative decoding.
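
For illustration, here is a minimal usage sketch. The model names and prompts are placeholders and not taken from this PR; the call itself just uses the existing generate() / assistant_model API:

```python
# Minimal usage sketch (model names are placeholders): batched speculative
# decoding by passing `assistant_model` to `generate()` with batch_size > 1.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "facebook/opt-1.3b"            # target (main) model
assistant_checkpoint = "facebook/opt-125m"  # draft model, same tokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.padding_side = "left"  # decoder-only models should be left-padded for generation
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

prompts = [
    "The capital of France is",
    "Speculative decoding speeds up generation by",
    "Alice and Bob are playing a game where",
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Before this PR, batch_size > 1 together with assistant_model raised a ValueError.
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=50,
    do_sample=False,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```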

Key Features & Limitations

  • Batch Support: Enabled batch_size > 1 for standard speculative decoding.
  • Tokenizer Constraint: Batched execution is currently restricted to scenarios where the draft (assistant) and target (main) models share the same tokenizer.
    • Universal Assisted Generation (UAG)—where models use different tokenizers—remains restricted to batch_size = 1 and will still raise a ValueError if used with a batch.

Technical Implementation Details

Supporting batches requires dealing with "ragged tensors", where different sequences in a batch accept a different number of speculative tokens, leading to misalignment in tensor shapes and KV caches. The implementation addresses this via the following mechanisms:

1. Assistant Cache Alignment

  • When reusing the assistant model for the next draft generation, the previous KV cache must match the current state of the sequences.
  • An align_cache function was implemented in candidate_generator.py. It identifies which parts of the old cache still match the new inputs and removes the irrelevant keys and values, ensuring the assistant's cache correctly reflects the accepted tokens (a sketch of the idea follows this list).
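
To make the idea concrete, here is a minimal sketch of the alignment step for one layer's key/value tensors. The helper name and signature are hypothetical and do not mirror the PR's actual align_cache code; it only illustrates "keep the cache prefix that still matches the new inputs, drop the rest":

```python
import torch

def align_cache_sketch(past_key, past_value, cached_ids, new_ids):
    """Hypothetical illustration, not the PR's align_cache implementation.

    past_key, past_value: (batch, num_heads, cache_len, head_dim) for one layer.
    cached_ids:           (batch, cache_len) token ids the cache was built from.
    new_ids:              (batch, new_len) current, repadded input ids.
    """
    cmp_len = min(cached_ids.shape[1], new_ids.shape[1])
    # 1 while the cached token still agrees with the new input, per row.
    still_valid = (cached_ids[:, :cmp_len] == new_ids[:, :cmp_len]).int()
    # cumprod stays 1 up to the first disagreement, so the row sum is the
    # length of the matching prefix for each batch item.
    matched_len = still_valid.cumprod(dim=1).sum(dim=1)
    # Crop every row to the shortest matched prefix so tensors stay rectangular.
    keep = int(matched_len.min())
    return past_key[:, :, :keep, :], past_value[:, :, :keep, :], keep
```

In the actual cache this has to be applied per layer, and with left padding the comparison runs against the repadded inputs; the sketch only conveys the prefix-matching idea.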

2. Repadding Inputs and Main Model Cache

  • After verification, if Batch Item A accepts 2 tokens and Batch Item B accepts 4, the resulting input tensor becomes ragged.
  • The rejected tokens and the padding tokens are removed from each item, and the entire batch is repadded to a common length. The same repadding is applied to the main model's KV cache (a minimal sketch follows).
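
A minimal sketch of the repadding idea (hypothetical helper; left padding assumed, as is standard for decoder-only generation):

```python
import torch

def repad_after_verification(sequences, pad_token_id):
    """Hypothetical illustration of repadding a ragged batch.

    `sequences` holds, per batch item, only the tokens that survived
    verification (rejected tokens and old padding already removed).
    Left-pads everything back to a common length so the batch is
    rectangular again for the next iteration.
    """
    max_len = max(seq.shape[0] for seq in sequences)
    rows = []
    for seq in sequences:
        pad = seq.new_full((max_len - seq.shape[0],), pad_token_id)
        rows.append(torch.cat([pad, seq]))  # left padding
    return torch.stack(rows)

# Example: item A kept 2 new tokens, item B kept 4 -> ragged lengths 7 vs 9.
a = torch.arange(7)
b = torch.arange(9)
print(repad_after_verification([a, b], pad_token_id=0).shape)  # torch.Size([2, 9])
```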

3. Vectorized Verification & Sampling

  • Greedy Search: The acceptance logic (checking whether candidate tokens match the target output) now uses vectorized cumsum operations to identify the index of the first mismatch for every item in the batch simultaneously (see the sketch after this list).
  • Sampling: _speculative_sampling has been updated to calculate acceptance probabilities ($p$ vs $q$) and handle token rejection/resampling for the entire batch in parallel.
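
For the greedy path, the first-mismatch search can be done without a Python loop. A small illustrative sketch (not the exact code from this PR; the batched _speculative_sampling applies the same vectorized pattern to the p-vs-q acceptance test):

```python
import torch

def count_accepted(candidate_tokens, target_tokens):
    """Illustrative sketch: how many draft tokens each batch item accepts.

    candidate_tokens, target_tokens: (batch, num_candidates) -- the draft's
    proposals and the target model's greedy picks at the same positions.
    """
    matches = (candidate_tokens == target_tokens).int()
    positions = torch.arange(1, matches.shape[1] + 1, device=matches.device)
    # The running match count equals the position index only while every
    # token so far has matched, i.e. before the first mismatch.
    all_matched_so_far = matches.cumsum(dim=1) == positions
    return all_matched_so_far.sum(dim=1)

cand = torch.tensor([[5, 7, 9, 2], [5, 7, 9, 2]])
tgt = torch.tensor([[5, 7, 1, 2], [5, 7, 9, 2]])
print(count_accepted(cand, tgt))  # tensor([2, 4])
```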

4. Dynamic Heuristics (update_candidate_strategy)

  • The heuristic that adjusts the number of assistant tokens (num_assistant_tokens) now aggregates statistics across the batch (e.g., the average number of matches) to determine whether the draft model's step size should increase or decrease (a sketch follows this list).
  • For the dynamic confidence threshold (ROC curve calculation), matches and probabilities are collected per batch item, flattened, and then used to update the threshold.
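
As an illustration of the aggregation, here is a sketch of a batched update rule. The +2 / -1 step sizes mirror the existing single-sequence "heuristic" schedule and are an assumption here, not a quote of the PR's code:

```python
import torch

def update_num_assistant_tokens(num_assistant_tokens, num_matches):
    """Illustrative sketch of a batched num_assistant_tokens heuristic.

    num_matches: (batch,) accepted-token counts from the last verification step.
    Aggregates across the batch (mean) and grows the draft step size when, on
    average, all candidates were accepted; shrinks it otherwise.
    """
    avg_matches = num_matches.float().mean().item()
    if avg_matches >= num_assistant_tokens:      # candidates fully accepted on average
        num_assistant_tokens += 2.0
    else:
        num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)
    return num_assistant_tokens

print(update_num_assistant_tokens(5.0, torch.tensor([5, 5, 5])))  # 7.0
print(update_num_assistant_tokens(5.0, torch.tensor([2, 5, 3])))  # 4.0
```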

Note on max_new_tokens

When using speculative decoding with batch_size > 1, the behavior of max_new_tokens requires some nuance. Because the generation loop serves the entire batch, it runs until one item satisfies the stopping criteria. Due to the repadding and the varying acceptance rates, the other items in the batch may effectively stop generating before reaching the global max_new_tokens limit.

Batch Size = 3: Assistant vs Standard Decoding

WITHOUT assistant_model: 9.2215 seconds
WITH assistant_model: 7.4744 seconds

| Sequence | # Tokens w/ Assistant | # Tokens w/o Assistant |
|----------|-----------------------|------------------------|
| 1        | 50                    | 50                     |
| 2        | 41                    | 50                     |
| 3        | 40                    | 50                     |

I look forward to your feedback and comments so that this can be included in the next release.

- Add batch speculative decoding functionality for batch_size > 1
- Update candidate generator to handle batch processing
- Enhance generation utils with batch speculative decoding support
- Add cache utilities for batch speculative decoding
- Update tests for batch speculative decoding