From 1bdc016c8f3e82dea100859cbc645a5f035ed386 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Thu, 4 Dec 2025 21:19:36 +0000
Subject: [PATCH 1/2] wip

Signed-off-by: Lucas Wilkinson
---
 .../models/language/generation/test_hybrid.py |   7 +-
 tests/v1/cudagraph/test_cudagraph_dispatch.py |  16 +--
 vllm/config/compilation.py                    |  33 ------
 vllm/config/vllm.py                           |   7 --
 vllm/v1/cudagraph_dispatcher.py               |  38 ++++++-
 vllm/v1/spec_decode/eagle.py                  | 105 +++++++-----------
 vllm/v1/worker/gpu_model_runner.py            |  20 ++--
 7 files changed, 97 insertions(+), 129 deletions(-)

diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 37830093cd3c..dd5b7fdc3e1c 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -9,6 +9,7 @@
 from tests.utils import multi_gpu_test
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
+from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 
 from ...utils import check_logprobs_close, check_outputs_equal
 
@@ -172,7 +173,11 @@ def test_mamba_cache_cg_padding(
     tensor dimensions aren't compatible.
     """
     vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
-    while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)):
+    cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
+    while (
+        len(example_prompts)
+        == cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens
+    ):
         example_prompts.append(example_prompts[0])
 
     try:
diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py
index b86534d3d438..e9f281a3dd37 100644
--- a/tests/v1/cudagraph/test_cudagraph_dispatch.py
+++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -58,9 +58,6 @@ def _create_vllm_config(
     )
     compilation_config.post_init_cudagraph_sizes()
-    mock_config.pad_for_cudagraph = (
-        lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
-    )
     return mock_config
 
 
@@ -161,10 +158,13 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
         assert rt_mode == CUDAGraphMode.NONE
         assert key == BatchDescriptor(num_tokens=15)
 
-        # 4. Cascade attention should have a fall back mode
+        # 4. piecewise_or_eager_only should have a fall back mode
         desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
         rt_mode, key = dispatcher.dispatch(
-            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
+            num_tokens=8,
+            uniform_decode=False,
+            has_lora=False,
+            piecewise_or_eager_only=True,
         )
         if "PIECEWISE" in cudagraph_mode_str:  # string contains check
             assert rt_mode == CUDAGraphMode.PIECEWISE
@@ -357,7 +357,7 @@ def test_capture_replay_bypass_logic(self):
         ):
             full_wrapper(input_1)
 
-        rt_mode, key = self.dispatcher.dispatch(desc_1)
+        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
         # 1. Capture first shape
         action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
         assert action == "capture_global"
@@ -366,7 +366,7 @@ def test_capture_replay_bypass_logic(self):
         action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
         assert action == "replay"
 
-        rt_mode, key = self.dispatcher.dispatch(desc_2)
+        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
         # 3. Capture second shape
         action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
         assert action == "capture_global"
@@ -378,7 +378,7 @@ def test_capture_replay_bypass_logic(self):
         action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
         assert action == "replay"
 
         # 5. Bypass if no key match
-        rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
+        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
         assert rt_mode == CUDAGraphMode.NONE
         action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
         assert action == "bypass"
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 0f876c38169a..d57ac29a469b 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -528,15 +528,6 @@ class CompilationConfig:
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
 
-    bs_to_padded_graph_size: list[int] = field(
-        default=None,  # type: ignore
-        init=False,
-    )
-    """optimization:
-    Intuitively, bs_to_padded_graph_size should be dict[int, int].
-    since we know all keys are in a range [0, max_cudagraph_capture_size],
-    we can optimize it to list[int] for better lookup performance."""
-
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
     """custom ops that are enabled"""
@@ -586,7 +577,6 @@ def compute_hash(self) -> str:
             "debug_dump_path",
             "cache_dir",
             "local_cache_dir",
-            "bs_to_padded_graph_size",
             "traced_files",
             "compilation_time",
             "static_forward_context",
@@ -606,7 +596,6 @@ def __repr__(self) -> str:
             "enabled_custom_ops": True,
             "disabled_custom_ops": True,
             "compilation_time": True,
-            "bs_to_padded_graph_size": True,
             "traced_files": True,
             "inductor_compile_config": {
                 "post_grad_custom_post_pass": True,
@@ -827,7 +816,6 @@ def post_init_cudagraph_sizes(self) -> None:
         """To complete the initialization after cudagraph related configs are set.
         This includes:
         - initialize compile_sizes
-        - pre-compute the mapping bs_to_padded_graph_size
         """
 
         computed_compile_sizes = []
@@ -851,9 +839,6 @@ def post_init_cudagraph_sizes(self) -> None:
         if self.cudagraph_capture_sizes:
             assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size
 
-        # May get recomputed in the model runner if adjustment is needed for spec-decode
-        self.compute_bs_to_padded_graph_size()
-
     def set_splitting_ops_for_v1(self):
         # To compatible with OOT hardware plugin platform (for example vllm-ascend)
         # which currently only supports sequence parallelism in eager mode.
@@ -1050,21 +1035,3 @@ def adjust_cudagraph_sizes_for_spec_decode(
 
         self.max_cudagraph_capture_size = rounded_sizes[-1]
         self.cudagraph_capture_sizes = rounded_sizes
-
-        # Recompute after adjusting the cudagraph sizes
-        self.compute_bs_to_padded_graph_size()
-
-    def compute_bs_to_padded_graph_size(self):
-        # pre-compute the mapping from batch size to padded graph size
-        self.bs_to_padded_graph_size = [
-            0 for i in range(self.max_cudagraph_capture_size + 1)
-        ]
-        for end, start in zip(
-            self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
-            [0] + self.cudagraph_capture_sizes,
-        ):
-            for bs in range(start, end):
-                if bs == start:
-                    self.bs_to_padded_graph_size[bs] = start
-                else:
-                    self.bs_to_padded_graph_size[bs] = end
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 5b3a9c437662..2a13fa9eb43a 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -310,13 +310,6 @@ def compute_hash(self) -> str:
         ]
         return hash_str
 
-    def pad_for_cudagraph(self, batch_size: int) -> int:
-        # if batch_size > self.compilation_config.max_cudagraph_capture_size,
-        # it should raise an IndexError.
-        # the caller should make sure the batch_size is within the range,
-        # i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
-        return self.compilation_config.bs_to_padded_graph_size[batch_size]
-
     def enable_trace_function_call_for_thread(self) -> None:
         """
         Set up function tracing for the current thread,
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
index ef0f8d9e6745..bc3411e62467 100644
--- a/vllm/v1/cudagraph_dispatcher.py
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -58,12 +58,27 @@ def __init__(self, vllm_config: VllmConfig):
 
         self.keys_initialized = False
 
+    def _compute_bs_to_padded_graph_size(self) -> None:
+        """Pre-compute the mapping from batch size to padded graph size."""
+        max_size = self.compilation_config.max_cudagraph_capture_size
+        capture_sizes = self.compilation_config.cudagraph_capture_sizes
+        self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
+        for end, start in zip(
+            capture_sizes + [max_size + 1],
+            [0] + capture_sizes,
+        ):
+            for bs in range(start, end):
+                if bs == start:
+                    self._bs_to_padded_graph_size[bs] = start
+                else:
+                    self._bs_to_padded_graph_size[bs] = end
+
     def _create_padded_batch_descriptor(
         self, num_tokens: int, uniform_decode: bool, has_lora: bool
     ) -> BatchDescriptor:
         max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
         uniform_decode_query_len = self.uniform_decode_query_len
-        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+        num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
 
         if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
             num_reqs = num_tokens_padded // uniform_decode_query_len
@@ -140,18 +155,29 @@ def initialize_cudagraph_keys(
 
         self.keys_initialized = True
 
+        self._compute_bs_to_padded_graph_size()
+
     def dispatch(
         self,
         num_tokens: int,
-        uniform_decode: bool,
-        has_lora: bool,
-        use_cascade_attn: bool = False,
+        uniform_decode: bool = False,
+        has_lora: bool = False,
+        piecewise_or_eager_only: bool = False,
     ) -> tuple[CUDAGraphMode, BatchDescriptor]:
         """
-        Given conditions(e.g.,batch descriptor and if using cascade attention),
+        Given conditions (e.g., batch descriptor and if using piecewise only),
         dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
+
+        Args:
+            num_tokens: Number of tokens in the batch.
+            uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
+                length is uniform_decode_query_len).
+            has_lora: Whether LoRA is active.
+            piecewise_or_eager_only: If True, skip FULL cudagraph checks and
+                return PIECEWISE or NONE only. (Can be used for features like cascade
+                attention that are not supported by full cudagraphs.)
         """
         if (
             not self.keys_initialized
@@ -165,7 +191,7 @@ def dispatch(
         )
         relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
 
-        if not use_cascade_attn:
+        if not piecewise_or_eager_only:
             # check if key exists for full cudagraph
             if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
                 return CUDAGraphMode.FULL, batch_desc
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 1c7845a14b74..0ddb0937b951 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -10,7 +10,6 @@
 
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import (
-    CompilationMode,
     CUDAGraphMode,
     VllmConfig,
     get_layers_from_vllm_config,
@@ -37,6 +36,7 @@
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
 )
+from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import _SAMPLING_EPS
@@ -96,24 +96,13 @@ def __init__(
             self._get_eagle3_use_aux_hidden_state_from_config()
         )
 
-        self.use_cuda_graph = False
         self.compilation_config = self.vllm_config.compilation_config
-        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
-            cudagraph_mode = self.compilation_config.cudagraph_mode
-            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
-                CUDAGraphMode.PIECEWISE
-            ):
-                logger.warning(
-                    "Currently the eagle proposer only supports cudagraph_mode "
-                    "PIECEWISE, if you want the drafter to use cuda graphs, "
-                    "please set compilation_config.cudagraph_mode to PIECEWISE "
-                    "or FULL_AND_PIECEWISE"
-                )
-            self.use_cuda_graph = (
-                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
-                and not self.speculative_config.enforce_eager
-            )
+
+        # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
+        # Keys are initialized later via initialize_cudagraph_keys() called from
+        # gpu_model_runner._check_and_update_cudagraph_mode after
+        # adjust_cudagraph_sizes_for_spec_decode is called.
+        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
 
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(
@@ -216,6 +205,21 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor):
         else:
             self.positions[:num_tokens] = positions
 
+    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
+        """Initialize cudagraph dispatcher keys for eagle.
+
+        Eagle only supports PIECEWISE cudagraphs (via mixed_mode).
+        This should be called after adjust_cudagraph_sizes_for_spec_decode.
+        """
+        eagle_cudagraph_mode = (
+            cudagraph_mode.mixed_mode()
+            if not self.speculative_config.enforce_eager
+            else CUDAGraphMode.NONE
+        )
+        self.cudagraph_dispatcher.initialize_cudagraph_keys(
+            eagle_cudagraph_mode, uniform_decode_query_len=1
+        )
+
     def propose(
         self,
         # [num_tokens]
@@ -285,16 +289,10 @@ def propose(
             num_tokens_padded=num_tokens,
         )
 
-        cudagraph_runtime_mode = CUDAGraphMode.NONE
-        if (
-            self.use_cuda_graph
-            and num_tokens_dp_padded
-            <= self.compilation_config.max_cudagraph_capture_size
-        ):
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
-            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-        else:
-            num_input_tokens = num_tokens_dp_padded
+        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
+            num_tokens_dp_padded
+        )
+        num_input_tokens = batch_desc.num_tokens
 
         if num_tokens_across_dp is not None:
             num_tokens_across_dp[self.dp_rank] = num_input_tokens
@@ -389,16 +387,10 @@ def propose(
             num_tokens_padded=batch_size,
         )
 
-        if (
-            self.use_cuda_graph
-            and batch_size_dp_padded
-            <= self.compilation_config.max_cudagraph_capture_size
-        ):
-            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
-            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-        else:
-            input_batch_size = batch_size_dp_padded
-            cudagraph_runtime_mode = CUDAGraphMode.NONE
+        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
+            batch_size_dp_padded
+        )
+        input_batch_size = batch_desc.num_tokens
 
         if batch_size_across_dp is not None:
             batch_size_across_dp[self.dp_rank] = input_batch_size
@@ -793,15 +785,10 @@ def propose_tree(
         self.positions[:num_tokens] = tree_positions.view(-1)
         self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)
 
-        if (
-            self.use_cuda_graph
-            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
-        ):
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
-            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-        else:
-            num_input_tokens = num_tokens
-            cudagraph_runtime_mode = CUDAGraphMode.NONE
+        cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
+            num_tokens
+        )
+        num_input_tokens = batch_desc.num_tokens
         # Run the model.
         with set_forward_context(
             per_layer_attn_metadata,
@@ -1139,9 +1126,6 @@ def dummy_run(
         use_cudagraphs=True,
         is_graph_capturing=False,
     ) -> None:
-        # Determine if CUDA graphs should be used for this run.
-        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
-
         # FIXME: when using tree-based specdec, adjust number of forward-passes
         # according to the depth of the tree.
        for fwd_idx in range(
@@ -1152,16 +1136,10 @@ def dummy_run(
                 num_tokens_unpadded=num_tokens,
                 num_tokens_padded=num_tokens,
             )
-            if (
-                cudagraphs_enabled
-                and num_tokens_dp_padded
-                <= self.compilation_config.max_cudagraph_capture_size
-            ):
-                num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                    num_tokens_dp_padded
-                )
-            else:
-                num_input_tokens = num_tokens_dp_padded
+            cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
+                num_tokens_dp_padded
+            )
+            num_input_tokens = batch_desc.num_tokens
 
             if num_tokens_across_dp is not None:
                 num_tokens_across_dp[self.dp_rank] = num_input_tokens
@@ -1170,9 +1148,7 @@ def dummy_run(
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
-                if cudagraphs_enabled
-                else CUDAGraphMode.NONE,
+                cudagraph_runtime_mode=cudagraph_runtime_mode,
             ):
                 if self.supports_mm_inputs:
                     input_ids = None
@@ -1262,7 +1238,8 @@ def _pad_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
-            allow_dp_padding=self.use_cuda_graph,
+            allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode
+            != CUDAGraphMode.NONE,
             num_tokens_padded=num_tokens_padded,
             uniform_decode=None,
             num_scheduled_tokens_per_request=None,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ee28f477a26a..6854a4a8159e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2032,15 +2032,11 @@ def _prepare_kv_sharing_fast_prefill(
         self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
             logits_indices[-1].item()
         )
-        if (
-            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-            and num_logits <= self.cudagraph_batch_sizes[-1]
-        ):
-            # Use piecewise CUDA graphs.
-            # Add padding to the batch size.
-            num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
-        else:
-            num_logits_padded = num_logits
+        # Dispatch for the decoder portion of the model.
+        _, batch_desc = self.cudagraph_dispatcher.dispatch(
+            num_logits, piecewise_or_eager_only=True
+        )
+        num_logits_padded = batch_desc.num_tokens
         logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
             :num_logits_padded
         ]
@@ -2780,7 +2776,7 @@ def _determine_batch_execution_and_padding(
             lambda num_tokens: self.cudagraph_dispatcher.dispatch(
                 num_tokens=num_tokens,
                 has_lora=has_lora,
-                use_cascade_attn=use_cascade_attn,
+                piecewise_or_eager_only=use_cascade_attn,
                 uniform_decode=uniform_decode,
             )
             if not force_eager
@@ -4798,6 +4794,10 @@ def _check_and_update_cudagraph_mode(
             cudagraph_mode, self.uniform_decode_query_len
         )
 
+        # Initialize eagle's cudagraph dispatcher if using eagle spec decode.
+        if self.speculative_config and self.speculative_config.use_eagle():
+            self.drafter.initialize_cudagraph_keys(cudagraph_mode)
+
     def calculate_reorder_batch_threshold(self) -> None:
         """
         Choose the minimum reorder batch threshold from all attention groups.
From 575ad80bb548dfe229f95c84137545a0631bd05b Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 5 Dec 2025 16:53:40 +0000
Subject: [PATCH 2/2] fix precommit

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/worker/gpu_model_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6854a4a8159e..41fef55eb5ab 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4796,6 +4796,7 @@ def _check_and_update_cudagraph_mode(
 
         # Initialize eagle's cudagraph dispatcher if using eagle spec decode.
         if self.speculative_config and self.speculative_config.use_eagle():
+            assert isinstance(self.drafter, EagleProposer)
             self.drafter.initialize_cudagraph_keys(cudagraph_mode)
 
     def calculate_reorder_batch_threshold(self) -> None:
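Usage sketch (illustrative only, not part of the patches above): with
VllmConfig.pad_for_cudagraph() removed, callers obtain the padded token count
from CudagraphDispatcher, as the updated test and eagle proposer do. The model
name below is a placeholder, and keys must be initialized before dispatch()
can return a non-NONE runtime mode.

    from vllm.config import CUDAGraphMode
    from vllm.engine.arg_utils import EngineArgs
    from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

    # Build a config the same way the updated test_hybrid.py does.
    vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
    dispatcher = CudagraphDispatcher(vllm_config)

    # Initialize dispatch keys (mirrors what the model runner and eagle proposer do);
    # this also precomputes the batch-size -> padded-graph-size table.
    dispatcher.initialize_cudagraph_keys(
        vllm_config.compilation_config.cudagraph_mode,
        uniform_decode_query_len=1,
    )

    # Dispatch a 13-token batch; the returned BatchDescriptor carries the padded
    # size that previously came from vllm_config.pad_for_cudagraph(13).
    runtime_mode, batch_desc = dispatcher.dispatch(num_tokens=13)
    print(runtime_mode, batch_desc.num_tokens)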