7 changes: 6 additions & 1 deletion tests/models/language/generation/test_hybrid.py
@@ -9,6 +9,7 @@
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

from ...utils import check_logprobs_close, check_outputs_equal

@@ -172,7 +173,11 @@ def test_mamba_cache_cg_padding(
tensor dimensions aren't compatible.
"""
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)):
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
while (
len(example_prompts)
== cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens
):
Comment on lines 175 to +180


P1: Dispatch loop never terminates in hybrid padding test

In test_mamba_cache_cg_padding a CudagraphDispatcher is created and immediately used in the while condition without ever calling initialize_cudagraph_keys, so dispatch returns a BatchDescriptor with num_tokens unchanged when keys_initialized is False. The condition len(example_prompts) == ...num_tokens therefore stays true on every iteration and the loop appends forever, hanging the test before it exercises any logic. This test now never completes under any configuration.
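
A minimal sketch of the direction this comment points at: initialize the dispatcher's keys before using it in the loop condition, so `dispatch()` actually pads `num_tokens`. This assumes `initialize_cudagraph_keys(cudagraph_mode, uniform_decode_query_len)` is the signature at this revision and that `compilation_config.cudagraph_mode` is populated by `create_engine_config()`; both are assumptions, not verified against this PR.

```python
# Hypothetical sketch (not part of this PR): initialize cudagraph keys first
# so the dispatcher computes its batch-size-to-padded-size table and dispatch()
# returns a padded BatchDescriptor instead of the unpadded passthrough.
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys(
    cudagraph_mode=vllm_config.compilation_config.cudagraph_mode,
    uniform_decode_query_len=1,  # assumption: plain decode, no spec-decode
)
while (
    len(example_prompts)
    == cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens
):
    example_prompts.append(example_prompts[0])
```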


example_prompts.append(example_prompts[0])

try:
16 changes: 8 additions & 8 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -58,9 +58,6 @@ def _create_vllm_config(
)

compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)

return mock_config

@@ -161,10 +158,13 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15)

# 4. Cascade attention should have a fall back mode
# 4. piecewise_or_eager_only should have a fall back mode
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
num_tokens=8,
uniform_decode=False,
has_lora=False,
piecewise_or_eager_only=True,
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
@@ -357,7 +357,7 @@ def test_capture_replay_bypass_logic(self):
):
full_wrapper(input_1)

rt_mode, key = self.dispatcher.dispatch(desc_1)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
# 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "capture_global"
@@ -366,7 +366,7 @@
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "replay"

rt_mode, key = self.dispatcher.dispatch(desc_2)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
# 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
assert action == "capture_global"
@@ -378,7 +378,7 @@
assert action == "replay"

# 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
assert action == "bypass"
33 changes: 0 additions & 33 deletions vllm/config/compilation.py
@@ -528,15 +528,6 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""

bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_cudagraph_capture_size],
we can optimize it to list[int] for better lookup performance."""

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled"""
@@ -586,7 +577,6 @@ def compute_hash(self) -> str:
"debug_dump_path",
"cache_dir",
"local_cache_dir",
"bs_to_padded_graph_size",
"traced_files",
"compilation_time",
"static_forward_context",
@@ -606,7 +596,6 @@ def __repr__(self) -> str:
"enabled_custom_ops": True,
"disabled_custom_ops": True,
"compilation_time": True,
"bs_to_padded_graph_size": True,
"traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
@@ -827,7 +816,6 @@ def post_init_cudagraph_sizes(self) -> None:
"""To complete the initialization after cudagraph related
configs are set. This includes:
- initialize compile_sizes
- pre-compute the mapping bs_to_padded_graph_size
"""

computed_compile_sizes = []
@@ -851,9 +839,6 @@
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()

def set_splitting_ops_for_v1(self):
# To compatible with OOT hardware plugin platform (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode.
@@ -1050,21 +1035,3 @@ def adjust_cudagraph_sizes_for_spec_decode(

self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes

# Recompute after adjusting the cudagraph sizes
self.compute_bs_to_padded_graph_size()

def compute_bs_to_padded_graph_size(self):
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_cudagraph_capture_size + 1)
]
for end, start in zip(
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end
7 changes: 0 additions & 7 deletions vllm/config/vllm.py
@@ -310,13 +310,6 @@ def compute_hash(self) -> str:
]
return hash_str

def pad_for_cudagraph(self, batch_size: int) -> int:
# if batch_size > self.compilation_config.max_cudagraph_capture_size,
# it should raise an IndexError.
# the caller should make sure the batch_size is within the range,
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]

def enable_trace_function_call_for_thread(self) -> None:
"""
Set up function tracing for the current thread,
38 changes: 32 additions & 6 deletions vllm/v1/cudagraph_dispatcher.py
@@ -58,12 +58,27 @@ def __init__(self, vllm_config: VllmConfig):

self.keys_initialized = False

def _compute_bs_to_padded_graph_size(self) -> None:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip(
capture_sizes + [max_size + 1],
[0] + capture_sizes,
):
for bs in range(start, end):
if bs == start:
self._bs_to_padded_graph_size[bs] = start
else:
self._bs_to_padded_graph_size[bs] = end

def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
@@ -140,18 +155,29 @@ def initialize_cudagraph_keys(

self.keys_initialized = True

self._compute_bs_to_padded_graph_size()

def dispatch(
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
uniform_decode: bool = False,
has_lora: bool = False,
piecewise_or_eager_only: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
Given conditions (e.g., batch descriptor and if using piecewise only),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).

Args:
num_tokens: Number of tokens in the batch.
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
piecewise_or_eager_only: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only (can be used for features like cascade
attention that are not supported by full cudagraphs).
"""
if (
not self.keys_initialized
@@ -165,7 +191,7 @@
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

if not use_cascade_attn:
if not piecewise_or_eager_only:
# check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc