From 800012c576f4b8e196c55d8324af8879b6b2d020 Mon Sep 17 00:00:00 2001 From: Yaniv Galron Date: Sat, 6 Dec 2025 15:28:06 +0200 Subject: [PATCH 1/4] Implement batch speculative decoding support - Add batch speculative decoding functionality for batch_size > 1 - Update candidate generator to handle batch processing - Enhance generation utils with batch speculative decoding support - Add cache utilities for batch speculative decoding - Update tests for batch speculative decoding --- src/transformers/cache_utils.py | 186 ++++++++++++++++++ .../generation/candidate_generator.py | 184 ++++++++++++----- src/transformers/generation/utils.py | 168 ++++++++++++---- tests/generation/test_candidate_generator.py | 2 + tests/generation/test_utils.py | 113 ++++++++--- tests/utils/test_cache_utils.py | 64 ++++++ 6 files changed, 601 insertions(+), 116 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 28f40952f2cd..09bb7799275b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -80,6 +80,75 @@ def reorder_cache(self, beam_idx: torch.LongTensor) -> None: self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) self.values = self.values.index_select(0, beam_idx.to(self.values.device)) + def align( + self, + new_seq_length: int, + copy_instructions: list[tuple[int, slice, slice]], + ) -> None: + """ + Align this layer's cache based on copy instructions. + + Args: + new_seq_length (`int`): The new sequence length for the aligned cache. + copy_instructions (`list[tuple[int, slice, slice]]`): List of (batch_idx, src_slice, dst_slice) tuples + specifying what to copy from the old cache to the new cache. + """ + if not self.is_initialized: + return + + B, H, _, D = self.keys.shape + new_keys = self.keys.new_zeros((B, H, new_seq_length, D)) + new_values = self.values.new_zeros((B, H, new_seq_length, D)) + + # Execute the pre-calculated copy instructions + for i, src_slice, dst_slice in copy_instructions: + new_keys[i, :, dst_slice] = self.keys[i, :, src_slice] + new_values[i, :, dst_slice] = self.values[i, :, src_slice] + + self.keys = new_keys + self.values = new_values + + def compress_and_repad_cache(self, padding_mask): + # padding_mask: True = Pad, False = Keep + B, H, S, D = self.keys.shape + + # 1. Compute lengths and dimensions + # Invert mask: True = Keep + keep_mask = ~padding_mask # [B, S] + lengths = keep_mask.sum(dim=1) # [B] + max_len = lengths.max().item() + + # 2. Allocate Output (Pre-filled with padding/zeros) + # We allocate directly in the final shape [B, H, max_len, D] + out_keys = self.keys.new_zeros((B, H, max_len, D)) + out_values = self.values.new_zeros((B, H, max_len, D)) + + # 3. Create the "Destination" mask for Left-Padding + # We want valid data to sit at the END of the sequence (Left Padding) + # Row i should have (max_len - length_i) pads, then valid data. + + # shape: [max_len] + range_tensor = torch.arange(max_len, device=self.keys.device) + # shape: [B, max_len] broadcast comparison + # Example: max_len=5, len=3. We want indices 2,3,4 to be True. + # range (0,1,2,3,4) >= (5-3=2) -> F,F,T,T,T + dest_mask = range_tensor >= (max_len - lengths.unsqueeze(1)) + + # 4. 
Perform the Copy (The Fast Part) + # We transpose (B, H, S, D) -> (B, S, H, D) so the mask (B, S) aligns + # This extracts ONLY the valid tokens into a flat buffer [Total_Valid, H, D] + valid_keys = self.keys.transpose(1, 2)[keep_mask] + valid_values = self.values.transpose(1, 2)[keep_mask] + + # Assign into output using the destination mask + # We transpose output to (B, max_len, H, D) to align with dest_mask (B, max_len) + out_keys.transpose(1, 2)[dest_mask] = valid_keys + out_values.transpose(1, 2)[dest_mask] = valid_values + + # 5. Assign back + self.keys = out_keys + self.values = out_values + class DynamicLayer(CacheLayerMixin): """ @@ -891,6 +960,94 @@ def __len__(self): # forward through all the layers return len(self.layers) + def align( + self, + new_ids: torch.LongTensor, + ids_in_cache: torch.LongTensor, + pad_token_id: int, + return_new_ids_in_cache: bool = False, + ): + """ + Align the cache when input sequences change (e.g., when batching different sequences together). + + Args: + new_ids (`torch.LongTensor`): The new input IDs after batching changes. + ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache. + pad_token_id (`int`): The padding token ID. + return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs. + + Returns: + `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor. + """ + # 1. Setup metadata (Shape: [Batch, Heads, Sequence_Length, Dimension]) + # We access the first layer just to get shapes and device + if len(self.layers) == 0 or not self.layers[0].is_initialized: + raise ValueError("Cache is not initialized") + + ref_layer = self.layers[0] + B, H, S_old, D = ref_layer.keys.shape + S_new = new_ids.shape[1] - 1 # Preserving your original sizing logic + + # 2. Pre-calculate "What to copy" for the whole batch ONCE. + + # Find start indices (Vectorized) + # Note: sum() assumes left-padding only. + old_start_indices = (ids_in_cache == pad_token_id).sum(dim=1) + new_start_indices = (new_ids == pad_token_id).sum(dim=1) + + # We will store the copy instructions here to apply to all layers later + # Format: List of tuples (batch_idx, source_slice, dest_slice) + copy_instructions = [] + + # We still loop over batch (B), but only once, not B * Layers + for i in range(B): + # Identify the content without padding + # We use standard python slicing here as it's just index math, very fast + o_start = old_start_indices[i].item() + n_start = new_start_indices[i].item() + + # Get the actual token sequences (views, not copies) + # We perform the comparison on the ID tensors (int64), which is cheap + trimmed_old = ids_in_cache[i, o_start:] + trimmed_new = new_ids[i, n_start:] + + min_len = min(len(trimmed_old), len(trimmed_new)) + + # Compare only up to min_len + # Using .ne() (not equal) and finding the first true is faster than checks + if min_len == 0: + copy_len = 0 + else: + # Find mismatch: (a != b) + mismatch = trimmed_old[:min_len].ne(trimmed_new[:min_len]) + if not mismatch.any(): + copy_len = min_len + else: + # argmax on boolean gives index of first True + copy_len = mismatch.int().argmax().item() + + if copy_len > 0: + # Define the slice objects now so we don't recreate them 32 times + src_slice = slice(o_start, o_start + copy_len) + # You align to the right (-length:) + dst_slice = slice(-copy_len, None) + copy_instructions.append((i, src_slice, dst_slice)) + + # 3. 
Apply changes to all layers using per-layer align method + for layer in self.layers: + layer.align(S_new, copy_instructions) + + if return_new_ids_in_cache: + new_input_ids_in_cache = ids_in_cache.new_zeros((B, S_new)) + # Execute the copy instructions for input IDs + for i, src_slice, dst_slice in copy_instructions: + new_input_ids_in_cache[i, dst_slice] = ids_in_cache[i, src_slice] + return new_input_ids_in_cache + + def compress_and_repad_cache(self, padding_mask): + for layer in self.layers: + layer.compress_and_repad_cache(padding_mask) + class DynamicCache(Cache): """ @@ -1277,6 +1434,35 @@ def batch_select_indices(self, indices: torch.Tensor): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) + def align( + self, + new_ids: torch.LongTensor, + ids_in_cache: torch.LongTensor, + pad_token_id: int, + return_new_ids_in_cache: bool = False, + ): + """ + Align the cache when input sequences change (e.g., when batching different sequences together). + This aligns both self-attention and cross-attention caches. + + Args: + new_ids (`torch.LongTensor`): The new input IDs after batching changes. + ids_in_cache (`torch.LongTensor`): The input IDs that were used to build the current cache. + pad_token_id (`int`): The padding token ID. + return_new_ids_in_cache (`bool`, *optional*, defaults to `False`): Whether to return the aligned input IDs. + + Returns: + `None` if `return_new_ids_in_cache=False`, otherwise the aligned input IDs tensor. + """ + if return_new_ids_in_cache: + aligned_ids = self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache) + return aligned_ids + else: + self.self_attention_cache.align(new_ids, ids_in_cache, pad_token_id, return_new_ids_in_cache) + + def compress_and_repad_cache(self, padding_mask): + self.self_attention_cache.compress_and_repad_cache(padding_mask) + def get_max_cache_shape(self) -> int: """Returns the maximum sequence length (i.e. max capacity) of the cache object""" return self.self_attention_cache.get_max_cache_shape() diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index c695858169e9..10794aaaa26c 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -41,13 +41,16 @@ class CandidateGenerator: """Abstract base class for all candidate generators that can be applied during assisted generation.""" - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: + def get_candidates( + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: """ Fetches the candidates to be tried for the current input. Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_ids_in_cache (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Return: `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be @@ -58,7 +61,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." 
) - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + def update_candidate_strategy( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True + ): """ Updates the candidate generation strategy based on the outcomes. @@ -70,6 +75,8 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F beam search or log softmax for each vocabulary token when using beam search num_matches (`int`): The number of matches between the candidate sequences and the model predictions. + assistant_used (`bool`): + Whether the assistant was used to generate the candidates. Assistant was not used if max_new_tokens is 0. """ raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can call " @@ -192,10 +199,14 @@ def __init__( and self.assistant_model.generation_config.assistant_confidence_threshold and type(self) is AssistedCandidateGenerator ): + # only needed for ROC curve calculation self.probs = [] self.matches = [] + self.clean_probs = [] - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: + def get_candidates( + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: """ Fetches the candidates to be tried for the current input. @@ -214,13 +225,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, if max_new_tokens == 0: return input_ids, None # Update past key values and masks - self._update_past_and_masks(input_ids) + self._update_past_and_masks(input_ids, assistant_ids_in_cache=assistant_ids_in_cache) # Generate candidates generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens) candidate_ids, candidate_logits = self._generate_candidates(generation_args) return candidate_ids, candidate_logits - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + def update_candidate_strategy( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True + ): """ Updates the candidate generation strategy based on the outcomes. @@ -230,9 +243,19 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search - num_matches (`int`): - The number of matches between the candidate sequences and the model predictions. + num_matches (`torch.LongTensor` of shape `(batch_size,)` or `int`): + The number of matches between the candidate sequences and the model predictions for each batch item. + If `int`, assumes `batch_size=1` for backward compatibility. + assistant_used (`bool`): + Whether the assistant was used to generate the candidates. Assistant was not used if max_new_tokens is 0. 
""" + # Handle backward compatibility: convert int to tensor + if isinstance(num_matches, int): + assert input_ids.shape[0] == 1, "num_matches should be a tensor of shape (batch_size,) when batch_size > 1" + num_matches = torch.tensor([num_matches], device=input_ids.device) + + batch_size = input_ids.shape[0] + # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the # cost of forecasting incorrect assistant tokens. @@ -240,33 +263,51 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F "heuristic", "heuristic_transient", }: - # len(scores[0])-1 is the number of candidates according to the target tokenizer. - if num_matches == len(scores[0]) - 1: + # For batch processing, we can use different strategies: + # Option 1: Use average matches across batch + avg_matches = num_matches.float().mean().item() + # Option 2: Use max matches (more aggressive) + # avg_matches = num_matches.float().max().item() + # Option 3: Use min matches (more conservative) + # avg_matches = num_matches.float().min().item() + max_candidate_length = scores.shape[1] - 1 + if avg_matches == max_candidate_length: self.num_assistant_tokens += 2.0 else: self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) - # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives. + # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. + # The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. + # A cost of 25% is assigned to false positives and 75% to false negatives. # This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG. if ( is_sklearn_available() and self.assistant_model.generation_config.assistant_confidence_threshold and type(self) is AssistedCandidateGenerator + and assistant_used ): - # update self.matches - self.matches.extend([1] * num_matches) - if len(self.probs) > len(self.matches): - self.matches.append(0) - - # update self.probs - excess_length = len(self.probs) - len(self.matches) - if excess_length > 0: - del self.probs[-excess_length:] - + # Update matches: add one match per token for each batch item + # Flatten the matches across all batches + for batch_idx in range(batch_size): + matches_count = num_matches[batch_idx].item() + item_matches = [1] * matches_count + if len(self.probs[len(self.matches)]) > matches_count: + # if the number of probabilities is greater than the number of matches, add a 0 to the end of the matches. + # this means we reject a token. + item_matches.append(0) + # taking only the relevant probabilities. for all the accepted tokens and the first rejected token. 
+ self.clean_probs.extend([self.probs[len(self.matches)][: len(item_matches)]]) + self.matches.extend([item_matches]) + + assert len(self.matches) == len(self.clean_probs), "matches and probs must have the same length" + clean_matches = np.concatenate(self.matches) + clean_probs = np.concatenate(self.clean_probs) + + # calculate ROC curve and update threshold if we have enough samples if ( - len(self.probs) > 5 and {0, 1}.issubset(self.matches) + len(clean_probs) > 5 and {0, 1}.issubset(clean_matches) ): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample - fpr, tpr, thresholds = roc_curve(self.matches, self.probs) + fpr, tpr, thresholds = roc_curve(clean_matches, clean_probs) fnr = 1 - tpr # Calculate the cost for each threshold @@ -286,15 +327,28 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> tuple[int, int]: return min_new_tokens, max_new_tokens def _update_past_and_masks( - self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1 + self, + input_ids: torch.LongTensor, + remove_from_pkv: int = 0, + num_added_tokens: int = 1, + assistant_ids_in_cache: torch.LongTensor = None, ) -> bool: """Update past key values and attention masks for subsequent generation rounds.""" has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None if has_past_key_values: - new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv - self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens) + if input_ids.shape[0] > 1: + self.assistant_kwargs["past_key_values"].align( + input_ids, assistant_ids_in_cache, self.generation_config._pad_token_tensor.item() + ) + else: + new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv + self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens) self.assistant_kwargs = _prepare_attention_mask( - self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder + self.assistant_kwargs, + input_ids.shape[-1], + self.assistant_model.config.is_encoder_decoder, + input_ids, + self.generation_config._pad_token_tensor.item(), ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1]) @@ -318,17 +372,24 @@ def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor, """Generate candidate sequences using the assistant model.""" assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs) self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values + candidate_logits = torch.stack( + assistant_output.scores, dim=1 + ) # shape: (batch_size, candidate_length, vocab_size) if ( is_sklearn_available() and self.assistant_model.generation_config.assistant_confidence_threshold and type(self) is AssistedCandidateGenerator ): - scores_tensor = torch.cat(assistant_output.scores, dim=0) - scores_softmax = torch.softmax(scores_tensor, dim=-1) - ids = assistant_output.sequences[-1, -len(assistant_output.scores) :] - p = scores_softmax[range(len(ids)), ids] + scores_softmax = torch.softmax( + candidate_logits, dim=-1 + ) # shape: (batch_size, candidate_length, vocab_size) + ids = assistant_output.sequences[ + :, -len(assistant_output.scores) : + ] # shape: (batch_size, candidate_length) + p = torch.gather(scores_softmax, dim=-1, index=ids.unsqueeze(-1)).squeeze( + -1 + ) # shape: (batch_size, candidate_length) self.probs.extend(p.tolist()) - candidate_logits = torch.stack(assistant_output.scores, 
dim=1) candidate_ids = assistant_output.sequences return candidate_ids, candidate_logits @@ -494,14 +555,17 @@ def convert_source_tokens_to_target_tokens( dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"] return dest_ids.to(input_ids.device) - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: + def get_candidates( + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: """ Fetches the candidates to be tried for the current input. Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - + assistant_ids_in_cache (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the assistant vocabulary that are in the cache. Return: `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length, @@ -519,7 +583,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0) - self._update_past_and_masks(assistant_input_ids, remove_from_pkv) + self._update_past_and_masks( + assistant_input_ids, remove_from_pkv, assistant_ids_in_cache=assistant_ids_in_cache + ) generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) self.assistant_kwargs.pop("attention_mask", None) @@ -797,7 +863,7 @@ def get_target_ids( Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids. """ - num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1] + num_new_tokens = assistant_candidate_ids.shape[1] - assistant_input_ids.shape[1] if num_new_tokens == 0: return target_input_ids else: @@ -919,7 +985,9 @@ def __init__( self._target_seq_len_with_candidates: int = 0 self._prev_assistant_ids: torch.LongTensor | None = None - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: + def get_candidates( + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: """ Simplified version of get_candidates that uses the translator cache for token conversion. 
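+
+        The optional `assistant_ids_in_cache` argument is passed through to `_update_past_and_masks`
+        so the assistant cache can be re-aligned with the current inputs when generating for batches.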
""" @@ -930,7 +998,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, if max_new_tokens == 0: return input_ids, None - self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens) + self._update_past_and_masks( + assistant_input_ids, num_added_tokens=num_added_tokens, assistant_ids_in_cache=assistant_ids_in_cache + ) generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) # Ensure scores are returned @@ -951,7 +1021,12 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, return target_candidate_ids, target_candidate_logits - def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool: + def _update_past_and_masks( + self, + assistant_input_ids: torch.LongTensor, + num_added_tokens: int = 1, + assistant_ids_in_cache: torch.LongTensor = None, + ) -> bool: if self._prev_assistant_ids is None: # Prepare attention mask for the first generation. # For subsequent generations, the attention mask is updated in super()_update_past_and_masks. @@ -1045,13 +1120,17 @@ def __init__( if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: + def get_candidates( + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: """ Fetches the candidates to be tried for the current input. Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_ids_in_cache (`torch.LongTensor`, *optional*): + Assistant model input IDs that are already in the cache. Not used by prompt lookup decoding. Return: `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. @@ -1139,7 +1218,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, # assisted_generation expects logits as well, but we don't have those here, so returning None return candidate_input_ids, None - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + def update_candidate_strategy( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int, assistant_used: bool = True + ): """ Updates the candidate generation strategy based on the outcomes. @@ -1151,6 +1232,8 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F beam search or log softmax for each vocabulary token when using beam search num_matches (`int`): The number of matches between the candidate sequences and the model predictions. + assistant_used (`bool`): + Whether the assistant was used to generate the candidates. Assistant was not used if max_new_tokens is 0. 
""" # Currently does nothing return @@ -1202,17 +1285,27 @@ def __init__( self.assistant_early_exit = self.generation_config.assistant_early_exit self.generation_config.assistant_early_exit = None - def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, torch.FloatTensor | None]: + def get_candidates( + self, input_ids: torch.LongTensor, assistant_ids_in_cache: torch.LongTensor = None + ) -> tuple[torch.LongTensor, torch.FloatTensor | None]: # Temporarily sets the number of hidden layers to the early exit value base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix) original_num_hidden_layers = base_model.config.num_hidden_layers base_model.config.num_hidden_layers = self.assistant_early_exit - candidate_ids, candidate_logits = super().get_candidates(input_ids) + candidate_ids, candidate_logits = super().get_candidates( + input_ids, assistant_ids_in_cache=assistant_ids_in_cache + ) base_model.config.num_hidden_layers = original_num_hidden_layers return candidate_ids, candidate_logits -def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]: +def _prepare_attention_mask( + model_kwargs: dict[str, Any], + new_length: int, + is_encoder_decoder: bool, + input_ids: torch.LongTensor | None = None, + pad_token_id: int | None = None, +) -> dict[str, Any]: """Expands or crops the model's mask for decoding purposes, to the defined length""" mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" @@ -1221,8 +1314,9 @@ def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_en mask = model_kwargs[mask_key] mask_length_diff = new_length - mask.shape[1] - - if mask_length_diff < 0: + if input_ids is not None and pad_token_id is not None: + model_kwargs[mask_key] = (input_ids != pad_token_id).to(mask.dtype) + elif mask_length_diff < 0: # not sure when we get into this case model_kwargs[mask_key] = mask[:, :mask_length_diff] elif mask_length_diff > 0: model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f97828a0862b..6408fbdf8af4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -26,6 +26,7 @@ import torch.distributed as dist from packaging import version from torch import nn +from torch.nn.utils.rnn import pad_sequence from ..cache_utils import ( Cache, @@ -3649,20 +3650,34 @@ def _assisted_decoding( # keep track of which sequences are already finished batch_size, cur_len = input_ids.shape[:2] if batch_size > 1: - raise ValueError("assisted generate is only supported for batch_size = 1") + if assistant_tokenizer is not None: + raise ValueError( + "assisted generate is only supported for batch_size > 1 if assistant_tokenizer is None" + ) + if generation_config.prompt_lookup_num_tokens is not None: + raise ValueError( + "assisted generate is only supported for batch_size > 1 if prompt_lookup_num_tokens is None" + ) + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs) this_peer_finished = False is_first_iteration = True # to preserve the same API in the output as other generation methods + assistant_ids_in_cache = None + pad_token_id = generation_config._pad_token_tensor while self._has_unfinished_sequences(this_peer_finished, 
synced_gpus, device=input_ids.device): cur_len = input_ids.shape[1] # 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device - candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + candidate_input_ids, candidate_logits = candidate_generator.get_candidates( + input_ids, assistant_ids_in_cache + ) + assistant_ids_in_cache = candidate_input_ids[:, :-1] candidate_input_ids = candidate_input_ids.to(self.device) if candidate_logits is not None: candidate_logits = candidate_logits.to(self.device) + assistant_used = candidate_logits is not None candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] is_done_candidate = stopping_criteria(candidate_input_ids, None) @@ -3674,7 +3689,11 @@ def _assisted_decoding( # 2.1. Prepare the model inputs candidate_kwargs = copy.copy(model_kwargs) candidate_kwargs = _prepare_attention_mask( - candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder + candidate_kwargs, + candidate_input_ids.shape[1], + self.config.is_encoder_decoder, + candidate_input_ids, + pad_token_id, ) candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) if "cache_position" in candidate_kwargs: @@ -3708,12 +3727,13 @@ def _assisted_decoding( # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) # 👉 Apply algorithm 1 from the speculative decoding paper (https://huggingface.co/papers/2211.17192). if do_sample and candidate_logits is not None: - valid_tokens, n_matches = _speculative_sampling( + valid_tokens_padded, n_matches = _speculative_sampling( candidate_input_ids, candidate_logits, candidate_length, new_logits, is_done_candidate, + pad_token_id, ) # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the @@ -3722,17 +3742,37 @@ def _assisted_decoding( else: if do_sample: probs = new_logits.softmax(dim=-1) - selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :] + selected_tokens = ( + torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1) + .squeeze(-1) + .view(probs.size(0), probs.size(1)) + ) else: selected_tokens = new_logits.argmax(dim=-1) candidate_new_tokens = candidate_input_ids[:, cur_len:] - n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum() + not_accepted_tokens = ~( + candidate_new_tokens == selected_tokens[:, :-1] + ) # shape: (batch_size, candidate_lenght) + n_matches = (not_accepted_tokens.cumsum(dim=-1) < 1).sum(dim=-1) # shape: (batch_size,) # Ensure we don't generate beyond max_len or an EOS token - if is_done_candidate and n_matches == candidate_length: - n_matches -= 1 - valid_tokens = selected_tokens[:, : n_matches + 1] + # Create fully padded tensor + valid_tokens_padded = torch.full( + (batch_size, candidate_length + 1), # YANIV: max_matches is at most candidate_length + pad_token_id, + dtype=torch.long, + device=selected_tokens.device, + ) + + # Build mask for which positions in each row should be filled + range_row = torch.arange(candidate_length + 1, device=selected_tokens.device).unsqueeze(0) + mask = range_row <= n_matches.unsqueeze(1) + + # Fill efficient batched slice + valid_tokens_padded[mask] = selected_tokens[:, : candidate_length + 1][ + mask + ] # a tensor of the selected_tokens with trailing pad_token_id # 4. Update variables according to the number of matching assistant tokens. 
Remember: the token generated # by the model after the last candidate match is also valid, as it is generated from a correct sequence. @@ -3740,23 +3780,25 @@ def _assisted_decoding( # is no match. # 4.1. Get the valid continuation, after the matching tokens - input_ids = torch.cat((input_ids, valid_tokens), dim=-1) + input_ids, outputs.past_key_values = repadd_batch_and_fix_cache( + input_ids, outputs.past_key_values, valid_tokens_padded, pad_token_id + ) if streamer is not None: - streamer.put(valid_tokens.cpu()) + streamer.put( + valid_tokens_padded.cpu() + ) # we might want to remove the padding here. and allow only for batch size 1 new_cur_len = input_ids.shape[1] - # 4.2. Discard past key values relative to unused assistant tokens - outputs.past_key_values.crop(new_cur_len - 1) - # 5. Update the candidate generation strategy if needed - candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) + candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches, assistant_used) # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs["cache_position"] = torch.tensor([new_cur_len - 1], device=input_ids.device, dtype=torch.long) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, - num_new_tokens=n_matches + 1, + num_new_tokens=0, ) if synced_gpus and this_peer_finished: continue @@ -3898,6 +3940,7 @@ def _speculative_sampling( candidate_length, new_logits, is_done_candidate, + pad_token_id, ): """ Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns @@ -3909,43 +3952,53 @@ def _speculative_sampling( # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens # selected by the assistant, respectively. q = candidate_logits.softmax(dim=-1) - q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) + q_i = q.gather(dim=-1, index=new_candidate_input_ids.unsqueeze(-1)).squeeze(-1) p = new_logits.softmax(dim=-1) - p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1) - probability_ratio = p_i / q_i + p_i = p.gather(dim=-1, index=new_candidate_input_ids.unsqueeze(-1)).squeeze(-1) + probability_ratio = p_i / q_i # [batch_size, candidate_length] # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio # (= keep with p = probability_ratio). Keep all the tokens until the first rejection r_i = torch.rand_like(probability_ratio) is_accepted = r_i <= probability_ratio - n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum() # this is `n` in algorithm 1 + n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum(dim=-1) # this is `n` in algorithm 1. shape: (batch_size,) # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior) - if is_done_candidate and n_matches == candidate_length: - # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model - # due to acceptance on EOS we fix `n_matches` - n_matches -= 1 - valid_tokens = new_candidate_input_ids[:, : n_matches + 1] - else: - # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. 
- gamma = candidate_logits.shape[1] - p_n_plus_1 = p[:, n_matches, :] - if n_matches < gamma: - q_n_plus_1 = q[:, n_matches, :] - p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) - p_prime.div_(p_prime.sum()) - else: - p_prime = p_n_plus_1 - t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :] - - # The selected tokens include the matches (if any) plus the next sampled tokens - if n_matches > 0: - valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1) + valid_tokens_list = [None] * len(n_matches) + for i in range(len(n_matches)): + if is_done_candidate[i] and n_matches[i] == candidate_length: + # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model + # due to acceptance on EOS we fix `n_matches` + n_matches[i] -= 1 + valid_tokens = new_candidate_input_ids[[i], : n_matches[i] + 1] else: - valid_tokens = t + # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling. + gamma = candidate_logits.shape[1] + p_n_plus_1 = p[i, n_matches[i], :] + if n_matches[i] < gamma: + q_n_plus_1 = q[i, n_matches[i], :] + p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0) + p_prime.div_(p_prime.sum()) + else: + p_prime = p_n_plus_1 + t = torch.multinomial(p_prime, num_samples=1)[None, :] - return valid_tokens, n_matches + # The selected tokens include the matches (if any) plus the next sampled tokens + if n_matches[i] > 0: + valid_tokens = torch.cat((new_candidate_input_ids[[i], : n_matches[i]], t), dim=-1) + else: + valid_tokens = t + valid_tokens_list[i] = valid_tokens + pad_valid_tokens = torch.full( + (len(valid_tokens_list), candidate_length + 1), + pad_token_id, + dtype=torch.long, + device=valid_tokens_list[0].device, + ) + for i in range(len(valid_tokens_list)): + pad_valid_tokens[i, -valid_tokens_list[i].shape[1] :] = valid_tokens_list[i] + return pad_valid_tokens, n_matches def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False): @@ -3972,3 +4025,34 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at new_tuple += (layer[..., i : i + 1, :last_dim_size],) outputs += (new_tuple,) return outputs + + +def repadd_batch_and_fix_cache(input_ids, past_key_values, accepted_tokens_padded, pad_token_id): + """ + params + input_ids: the input ids of the model (without the candidate tokens). shape: [B, S] + past_key_values: the past key values of the model. shape: [B, S, num_heads, head_dim] for each layer in the cache. + accepted_tokens_padded: the accepted tokens padded with the bonus token. shape: [B, candidate_length+1]. + The bonus token is not always the last token (!). The rejected tokens are replaced with the pad_token_id. + pad_token_id: the pad token id. + returns: + repadded_tensor: the repadded tensor. shape: [B, ..] + cache: the cache after modifying the keys and values. + + """ + next_input_ids = torch.cat([input_ids, accepted_tokens_padded], dim=1) # notive that accepted + # this will let us know which locations in the kv cache we should remove. + # we remove the last token because it is the bonus token and it does not appear in the cache. + cache_input_ids = next_input_ids.clone() + # last non zero token in each row is set to 0 because it is the bonus token and it does not appear in the cache. 
+ cache_input_ids[torch.arange(cache_input_ids.shape[0]), (cache_input_ids != pad_token_id).cumsum(1).argmax(1)] = ( + pad_token_id + ) + padding_mask = cache_input_ids[:, :-1] == pad_token_id + past_key_values.compress_and_repad_cache(padding_mask) + # 1. Filter out current padding and repad to minimum length. + next_input_ids_clean = [row[row != pad_token_id] for row in next_input_ids] + next_input_ids_padded = pad_sequence( + next_input_ids_clean, batch_first=True, padding_value=pad_token_id, padding_side="left" + ) + return next_input_ids_padded, past_key_values diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index 3a50a963a9a2..e85c786b0450 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -249,6 +249,8 @@ def setUp(self): self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id if self.assistant_tokenizer.bos_token_id is None: self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id + + self.generation_config._pad_token_tensor = torch.tensor([self.target_tokenizer.pad_token_id]).to(torch_device) self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) self.model_kwargs = { diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 36e1fc248a47..9c80fdd0e2a0 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2521,16 +2521,17 @@ def test_speculative_sampling(self): ] ] ) - last_assistant_token_is_eos = False + last_assistant_token_is_eos = [False] # length: batch_size validated_tokens, n_matches = _speculative_sampling( candidate_input_ids, candidate_logits, candidate_length, new_logits, last_assistant_token_is_eos, + pad_token_id=-1, ) - self.assertTrue(n_matches.item() == 2) - self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8]) + self.assertTrue(n_matches.tolist() == [2]) + self.assertTrue(validated_tokens.tolist()[0] == [-1, 1, 4, 8]) def test_speculative_sampling_target_distribution(self): """ @@ -2564,7 +2565,7 @@ def test_speculative_sampling_target_distribution(self): ] ] ) - last_assistant_token_is_eos = False + last_assistant_token_is_eos = [False] # length: batch_size last_validated_token = [] for _ in range(10_000): validated_tokens, n_matches = _speculative_sampling( @@ -2573,12 +2574,14 @@ def test_speculative_sampling_target_distribution(self): candidate_length, new_logits, last_assistant_token_is_eos, + pad_token_id=-1, ) - self.assertTrue(n_matches.item() == 2) - self.assertTrue(validated_tokens.tolist()[0][0] == 1) - self.assertTrue(validated_tokens.tolist()[0][1] == 4) - self.assertTrue(validated_tokens.tolist()[0][2] in [1, 3, 7, 8]) - last_validated_token.append(validated_tokens.tolist()[0][2]) + self.assertTrue(n_matches.tolist() == [2]) + self.assertTrue(validated_tokens.tolist()[0][0] == -1) # padding token + self.assertTrue(validated_tokens.tolist()[0][1] == 1) + self.assertTrue(validated_tokens.tolist()[0][2] == 4) + self.assertTrue(validated_tokens.tolist()[0][3] in [1, 3, 7, 8]) + last_validated_token.append(validated_tokens.tolist()[0][3]) # check that the most likely tokens are selected more often than the less likely ones last_token_counts = collections.Counter(last_validated_token) self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0) @@ -3416,7 +3419,9 @@ def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self): "do_sample": False, "assistant_model": assistant_model, } - 
model.generate(**inputs, **generation_kwargs) + output = model.generate(**inputs, **generation_kwargs) + print(tokenizer.batch_decode(output, skip_special_tokens=True)) + print(f"num_assistant_tokens: {assistant_model.generation_config.num_assistant_tokens}") # update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7 self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7)) @@ -3669,6 +3674,37 @@ def test_speculative_decoding_equals_regular_decoding(self): self.assertEqual(expected_out.shape, predicted_out.shape) self.assertTrue((expected_out == predicted_out).all().item()) + def test_batched_speculative_decoding_equals_regular_decoding(self): + draft_name = "HuggingFaceTB/SmolLM-135M" + target_name = "HuggingFaceTB/SmolLM-1.7B" + + batch_size = 4 + draft_model = AutoModelForCausalLM.from_pretrained(draft_name) + target_model = AutoModelForCausalLM.from_pretrained(target_name) + + tokenizer = AutoTokenizer.from_pretrained(target_name, padding_side="left") + + prompt_size = torch.randint(low=20, high=100, size=(1,)) + max_new_tokens = torch.randint(low=10, high=50, size=(1,)) + input_ids = (torch.rand(batch_size, prompt_size[0]) * 100).to(int) + 50 + + max_new_tokens_item = max_new_tokens[0].item() + expected_out = target_model.generate(input_ids, do_sample=False, max_new_tokens=max_new_tokens_item) + predicted_out = target_model.generate( + input_ids, + do_sample=False, + max_new_tokens=max_new_tokens_item, + assistant_model=draft_model, + tokenizer=tokenizer, + ) + + self.assertEqual(expected_out.shape, predicted_out.shape) + expected_out_decoded = tokenizer.batch_decode(expected_out, skip_special_tokens=True) + predicted_out_decode = tokenizer.batch_decode(predicted_out, skip_special_tokens=True) + for predicted_item, expected_item in zip(predicted_out_decode, expected_out_decoded): + min_length = min(len(predicted_item), len(expected_item)) + self.assertTrue(predicted_item[:min_length] == expected_item[:min_length]) + @pytest.mark.generate @require_torch_multi_accelerator def test_generate_with_static_cache_multi_accelerator(self): @@ -4863,12 +4899,12 @@ def setUp(self): generation_config=self.assistant_model.generation_config, model_kwargs=self.model_kwargs, ) - self.candidate_generator.probs = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1] - self.original_probs = self.candidate_generator.probs self.original_threshold = self.assistant_model.generation_config.assistant_confidence_threshold def assert_no_sklearn(self): with patch("transformers.generation.candidate_generator.is_sklearn_available", lambda: False): + self.candidate_generator.probs = [[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]] + self.original_probs = self.candidate_generator.probs self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) self.assertEqual(self.candidate_generator.matches, self.original_matches) self.assertEqual(self.candidate_generator.probs, self.original_probs) @@ -4881,76 +4917,95 @@ def test_update_candidate_strategy_no_matches_short(self, sklearn_available): self.original_matches = [] self.candidate_generator.matches = self.original_matches self.num_matches = 0 + self.candidate_generator.probs = [[0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]] + self.original_probs = self.candidate_generator.probs if sklearn_available: self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) - self.assertEqual(self.candidate_generator.matches, 
[0]) - self.assertEqual(self.candidate_generator.probs, [0.9]) + self.assertEqual(self.candidate_generator.matches, [[0]]) + self.assertEqual(self.candidate_generator.clean_probs, [[0.9]]) self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) else: self.assert_no_sklearn() @parameterized.expand([(is_sklearn_available(),), (False,)]) def test_update_candidate_strategy_with_mix_matches_3(self, sklearn_available): - self.original_matches = [1, 0, 1, 0, 1] + self.original_matches = [[1, 0], [1, 0], [1]] self.candidate_generator.matches = self.original_matches self.num_matches = 3 + self.candidate_generator.probs = [[0.9, 0.8], [0.7, 0.6], [0.5], [0.4, 0.3, 0.2, 0.1]] + self.candidate_generator.clean_probs = [[0.9, 0.8], [0.7, 0.6], [0.5]] + self.original_probs = self.candidate_generator.probs if sklearn_available: self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) - self.assertEqual(self.candidate_generator.matches, [1, 0, 1, 0, 1, 1, 1, 1, 0]) - self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + self.assertEqual(self.candidate_generator.matches, [[1, 0], [1, 0], [1], [1, 1, 1, 0]]) + self.assertEqual( + self.candidate_generator.clean_probs, [[0.9, 0.8], [0.7, 0.6], [0.5], [0.4, 0.3, 0.2, 0.1]] + ) self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2) else: self.assert_no_sklearn() @parameterized.expand([(is_sklearn_available(),), (False,)]) def test_update_candidate_strategy_with_matches_4(self, sklearn_available): - self.original_matches = [1, 1, 1, 1, 1] + self.original_matches = [[1, 1, 1, 1, 1]] self.candidate_generator.matches = self.original_matches self.num_matches = 4 + self.candidate_generator.probs = [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3, 0.2, 0.1]] + self.candidate_generator.clean_probs = [[0.9, 0.8, 0.7, 0.6, 0.5]] + self.original_probs = self.candidate_generator.probs if sklearn_available: self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) - self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 1]) - self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + self.assertEqual(self.candidate_generator.matches, [[1, 1, 1, 1, 1], [1, 1, 1, 1]]) + self.assertEqual(self.candidate_generator.clean_probs, [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3, 0.2, 0.1]]) self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) else: self.assert_no_sklearn() @parameterized.expand([(is_sklearn_available(),), (False,)]) def test_update_candidate_strategy_with_matches_3(self, sklearn_available): - self.original_matches = [1, 1, 1, 1, 1] + self.original_matches = [[1, 1, 1, 1, 1]] self.candidate_generator.matches = self.original_matches self.num_matches = 3 + self.candidate_generator.probs = [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3, 0.2, 0.1]] + self.candidate_generator.clean_probs = [[0.9, 0.8, 0.7, 0.6, 0.5]] + self.original_probs = self.candidate_generator.probs if sklearn_available: self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) - self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 1, 0]) - self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]) + self.assertEqual(self.candidate_generator.matches, [[1, 1, 1, 1, 1], [1, 1, 1, 0]]) + self.assertEqual(self.candidate_generator.clean_probs, [[0.9, 0.8, 0.7, 0.6, 
0.5], [0.4, 0.3, 0.2, 0.1]]) self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.2) else: self.assert_no_sklearn() @parameterized.expand([(is_sklearn_available(),), (False,)]) def test_update_candidate_strategy_with_matches_2(self, sklearn_available): - self.original_matches = [1, 1, 1, 1, 1] + self.original_matches = [[1, 1, 1, 1, 1]] self.candidate_generator.matches = self.original_matches self.num_matches = 2 + self.candidate_generator.probs = [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3, 0.2, 0.1]] + self.candidate_generator.clean_probs = [[0.9, 0.8, 0.7, 0.6, 0.5]] + self.original_probs = self.candidate_generator.probs if sklearn_available: self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) - self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 1, 0]) - self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]) + self.assertEqual(self.candidate_generator.matches, [[1, 1, 1, 1, 1], [1, 1, 0]]) + self.assertEqual(self.candidate_generator.clean_probs, [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3, 0.2]]) self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.3) else: self.assert_no_sklearn() @parameterized.expand([(is_sklearn_available(),), (False,)]) def test_update_candidate_strategy_with_matches_1(self, sklearn_available): - self.original_matches = [1, 1, 1, 1, 1] + self.original_matches = [[1, 1, 1, 1, 1]] self.candidate_generator.matches = self.original_matches self.num_matches = 1 + self.candidate_generator.probs = [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3, 0.2, 0.1]] + self.candidate_generator.clean_probs = [[0.9, 0.8, 0.7, 0.6, 0.5]] + self.original_probs = self.candidate_generator.probs if sklearn_available: self.candidate_generator.update_candidate_strategy(self.input_ids, None, self.num_matches) - self.assertEqual(self.candidate_generator.matches, [1, 1, 1, 1, 1, 1, 0]) - self.assertEqual(self.candidate_generator.probs, [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]) + self.assertEqual(self.candidate_generator.matches, [[1, 1, 1, 1, 1], [1, 0]]) + self.assertEqual(self.candidate_generator.clean_probs, [[0.9, 0.8, 0.7, 0.6, 0.5], [0.4, 0.3]]) self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) else: self.assert_no_sklearn() diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 64f2ca09a9b1..e8c771789141 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -56,6 +56,7 @@ convert_and_export_with_cache, pipeline, ) + from transformers.cache_utils import DynamicLayer from transformers.integrations.executorch import export_with_dynamic_cache @@ -114,6 +115,69 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 1, 10, 128)) +@require_torch +class TestAlignCache(unittest.TestCase): + """Test suite for the align_cache function.""" + + def setUp(self): + """Set up test fixtures.""" + self.device = torch_device + self.dtype = torch.float32 + self.pad_token_id = 0 + self.num_layers = 3 + self.num_heads = 4 + self.head_dim = 8 + + def _create_cache(self, seq_len: int, batch_size: int) -> Cache: + """Helper to create a cache with specified sequence length.""" + shape = (batch_size, self.num_heads, seq_len, self.head_dim) + layers = [] + for _ in range(self.num_layers): + layer = DynamicLayer() + # Initialize with dummy key/value states to set up the cache structure + key_states = torch.randn(*shape, dtype=self.dtype, device=self.device) + 
value_states = torch.randn(*shape, dtype=self.dtype, device=self.device) + layer.lazy_initialization(key_states) + layer.keys = key_states + layer.values = value_states + layers.append(layer) + return Cache(layers=layers) + + def test_align_cache(self): + """Test alignment.""" + input_ids_in_cache = torch.tensor( + [ + [56, 19712, 8182, 314, 354, 1440, 29, 7032, 2727, 338], + [0, 0, 0, 0, 23297, 314, 253, 6256, 1789, 338], + [0, 0, 0, 0, 0, 29968, 1380, 314, 4013, 975], + ], + dtype=torch.long, + ) + + current_input_ids = torch.tensor( + [ + [56, 19712, 8182, 314, 354, 1440, 29, 7032, 2727, 338, 2433], + [0, 0, 0, 23297, 314, 253, 6256, 1789, 338, 314, 4889], + [0, 0, 0, 0, 0, 0, 29968, 1380, 314, 4013, 28], + ], + dtype=torch.long, + ) + + ids_in_cache_after_fixing = torch.tensor( + [ + [56, 19712, 8182, 314, 354, 1440, 29, 7032, 2727, 338], + [0, 0, 0, 0, 23297, 314, 253, 6256, 1789, 338], + [0, 0, 0, 0, 0, 0, 29968, 1380, 314, 4013], + ], + dtype=torch.long, + ) + cache = self._create_cache(seq_len=input_ids_in_cache.shape[1], batch_size=input_ids_in_cache.shape[0]) + new_input_ids_in_cache = cache.align( + current_input_ids, input_ids_in_cache, self.pad_token_id, return_new_ids_in_cache=True + ) + self.assertTrue(torch.allclose(new_input_ids_in_cache, ids_in_cache_after_fixing)) + + def _skip_on_failed_cache_prerequisites(test, cache_implementation): """Function to skip tests on failed cache prerequisites, given a cache implementation""" # Installed dependencies From ab53bd4571c21ab8164a5b34fcfab633b23487c0 Mon Sep 17 00:00:00 2001 From: Yaniv Galron Date: Sat, 6 Dec 2025 19:17:39 +0200 Subject: [PATCH 2/4] fixing input alignemnet when decoder first token is a padding token --- src/transformers/generation/utils.py | 19 ++++++++++++++++--- tests/generation/test_candidate_generator.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6408fbdf8af4..6a137007ae28 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3666,6 +3666,7 @@ def _assisted_decoding( is_first_iteration = True # to preserve the same API in the output as other generation methods assistant_ids_in_cache = None pad_token_id = generation_config._pad_token_tensor + decoder_start_token_tensor = generation_config._decoder_start_token_tensor while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): cur_len = input_ids.shape[1] @@ -3781,7 +3782,12 @@ def _assisted_decoding( # 4.1. Get the valid continuation, after the matching tokens input_ids, outputs.past_key_values = repadd_batch_and_fix_cache( - input_ids, outputs.past_key_values, valid_tokens_padded, pad_token_id + input_ids, + outputs.past_key_values, + valid_tokens_padded, + pad_token_id, + self.config.is_encoder_decoder, + decoder_start_token_tensor, ) if streamer is not None: streamer.put( @@ -3945,7 +3951,6 @@ def _speculative_sampling( """ Applies sampling as in the speculative decoding paper (https://huggingface.co/papers/2211.17192, algorithm 1). Returns the selected tokens, as well as the number of candidate matches. - NOTE: Unless otherwise stated, the variable names match those in the paper. 
""" new_candidate_input_ids = candidate_input_ids[:, -candidate_length:] @@ -4027,7 +4032,9 @@ def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_at return outputs -def repadd_batch_and_fix_cache(input_ids, past_key_values, accepted_tokens_padded, pad_token_id): +def repadd_batch_and_fix_cache( + input_ids, past_key_values, accepted_tokens_padded, pad_token_id, is_encoder_decoder, decoder_start_token_tensor +): """ params input_ids: the input ids of the model (without the candidate tokens). shape: [B, S] @@ -4035,11 +4042,15 @@ def repadd_batch_and_fix_cache(input_ids, past_key_values, accepted_tokens_padde accepted_tokens_padded: the accepted tokens padded with the bonus token. shape: [B, candidate_length+1]. The bonus token is not always the last token (!). The rejected tokens are replaced with the pad_token_id. pad_token_id: the pad token id. + is_encoder_decoder: whether the model is an encoder-decoder model. + decoder_start_token_tensor: the decoder start token tensor. shape: [B, 1]. returns: repadded_tensor: the repadded tensor. shape: [B, ..] cache: the cache after modifying the keys and values. """ + if is_encoder_decoder and pad_token_id == decoder_start_token_tensor: + input_ids[:, 0] = -1 # placeholder to keep safe next_input_ids = torch.cat([input_ids, accepted_tokens_padded], dim=1) # notive that accepted # this will let us know which locations in the kv cache we should remove. # we remove the last token because it is the bonus token and it does not appear in the cache. @@ -4055,4 +4066,6 @@ def repadd_batch_and_fix_cache(input_ids, past_key_values, accepted_tokens_padde next_input_ids_padded = pad_sequence( next_input_ids_clean, batch_first=True, padding_value=pad_token_id, padding_side="left" ) + if is_encoder_decoder and pad_token_id == decoder_start_token_tensor: + next_input_ids_padded[:, 0] = decoder_start_token_tensor return next_input_ids_padded, past_key_values diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py index e85c786b0450..85303b6a1ed1 100644 --- a/tests/generation/test_candidate_generator.py +++ b/tests/generation/test_candidate_generator.py @@ -249,7 +249,7 @@ def setUp(self): self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id if self.assistant_tokenizer.bos_token_id is None: self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id - + self.generation_config._pad_token_tensor = torch.tensor([self.target_tokenizer.pad_token_id]).to(torch_device) self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device) From 36e802231e353fb067ae892fc108bd85c44adff0 Mon Sep 17 00:00:00 2001 From: Yaniv Galron Date: Sun, 7 Dec 2025 17:28:50 +0200 Subject: [PATCH 3/4] fixing problem in candidate generator --- src/transformers/generation/candidate_generator.py | 7 +++++-- src/transformers/generation/utils.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 10794aaaa26c..ed38f445cda8 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -1314,9 +1314,12 @@ def _prepare_attention_mask( mask = model_kwargs[mask_key] mask_length_diff = new_length - mask.shape[1] - if input_ids is not None and pad_token_id is not None: + if mask_length_diff == 0: + pass # no need to do anything + elif input_ids is not None and pad_token_id is not None: + # this could be a 
problem if the original mask is ignoring tokens that are not pad tokens. model_kwargs[mask_key] = (input_ids != pad_token_id).to(mask.dtype) - elif mask_length_diff < 0: # not sure when we get into this case + elif mask_length_diff < 0: model_kwargs[mask_key] = mask[:, :mask_length_diff] elif mask_length_diff > 0: model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6a137007ae28..03d8481bd68f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3693,8 +3693,8 @@ def _assisted_decoding( candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder, - candidate_input_ids, - pad_token_id, + candidate_input_ids if batch_size > 1 else None, + pad_token_id if batch_size > 1 else None, ) candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) if "cache_position" in candidate_kwargs: From 5d1870b29600d8f3d8315c9b21e1f2fd58396551 Mon Sep 17 00:00:00 2001 From: Yaniv Galron Date: Sun, 7 Dec 2025 17:32:29 +0200 Subject: [PATCH 4/4] fixup --- src/transformers/generation/candidate_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ed38f445cda8..b56fcaca4963 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -1315,9 +1315,9 @@ def _prepare_attention_mask( mask = model_kwargs[mask_key] mask_length_diff = new_length - mask.shape[1] if mask_length_diff == 0: - pass # no need to do anything + pass # no need to do anything elif input_ids is not None and pad_token_id is not None: - # this could be a problem if the original mask is ignoring tokens that are not pad tokens. + # this could be a problem if the original mask is ignoring tokens that are not pad tokens. model_kwargs[mask_key] = (input_ids != pad_token_id).to(mask.dtype) elif mask_length_diff < 0: model_kwargs[mask_key] = mask[:, :mask_length_diff]
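A minimal usage sketch of the batched assisted-generation path this series enables, assuming the patches
are applied. The checkpoints and arguments mirror the new
`test_batched_speculative_decoding_equals_regular_decoding` test; left padding and a configured pad token
are assumed, matching the pad-counting logic in `Cache.align` and `repadd_batch_and_fix_cache`:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Target/draft pair taken from the new test; any compatible pair sharing a tokenizer should work.
    target = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
    draft = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-1.7B", padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # batch_size > 1 needs a pad token

    prompts = ["The capital of France is", "Speculative decoding speeds up generation by"]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)

    out = target.generate(
        **inputs,
        assistant_model=draft,
        tokenizer=tokenizer,
        do_sample=False,
        max_new_tokens=32,
    )
    print(tokenizer.batch_decode(out, skip_special_tokens=True))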