104 changes: 95 additions & 9 deletions tests/v1/core/test_single_type_kv_cache_manager.py
@@ -314,12 +314,13 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):

def test_get_num_blocks_to_allocate():
block_size = 2
sliding_window_length = 2 * block_size
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4, # Placeholder value, not related to test result
sliding_window=sliding_window_length,
)

block_pool = BlockPool(
@@ -331,22 +332,83 @@ def test_get_num_blocks_to_allocate():
KVCacheBlock(i + 1) for i in range(5)
]

num_new_blocks_to_allocate, num_evictable_blocks_to_allocate = (
manager.get_num_blocks_to_allocate(
"1",
20 * block_size,
cached_blocks_1,
total_computed_tokens=len(cached_blocks_1) * block_size,
)
)
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
num_new_blocks_to_allocate == 10
and num_evictable_blocks_to_allocate == sliding_window_length // block_size
)

num_new_blocks_to_allocate, num_evictable_blocks_to_allocate = (
manager.get_num_blocks_to_allocate(
"2",
20 * block_size,
cached_blocks_2,
total_computed_tokens=len(cached_blocks_2) * block_size,
)
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
num_new_blocks_to_allocate == 10
and num_evictable_blocks_to_allocate == sliding_window_length // block_size
)


def test_evictable_cached_blocks_not_double_allocated():
block_size = 2
sliding_window_length = 2 * block_size
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window_length,
)

block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)

request_id = "req"
evictable_block = block_pool.blocks[1] # ref_cnt == 0, eviction candidate

num_new_blocks_to_allocate, num_evictable_blocks_to_allocate = (
manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=2 * block_size,
new_computed_blocks=[evictable_block],
total_computed_tokens=block_size,
)
)
# Free capacity check should count evictable cached blocks, but allocation
# should only allocate the truly new block.
assert num_new_blocks_to_allocate == 1 and num_evictable_blocks_to_allocate == 1

manager.save_new_computed_blocks(
request_id, [evictable_block], total_computed_tokens=block_size
)
new_blocks = manager.allocate_new_blocks(
request_id, num_new_blocks_to_allocate, num_tokens=4
)
assert len(new_blocks) == 1
assert len(manager.req_to_blocks[request_id]) == 2


def test_chunked_local_attention_get_num_blocks_to_allocate():
block_size = 2
attention_chunk_size = 2 * block_size
attention_spec = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
attention_chunk_size=4, # Placeholder value, not related to test result
attention_chunk_size=attention_chunk_size,
)

block_pool = BlockPool(
@@ -357,10 +419,34 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
KVCacheBlock(i + 1) for i in range(5)
]

assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
cached_blocks_3 = [KVCacheBlock(i + 1) for i in range(5)]

num_new_blocks_to_allocate, num_evictable_blocks_to_allocate = (
manager.get_num_blocks_to_allocate(
"1",
20 * block_size,
cached_blocks_1,
total_computed_tokens=len(cached_blocks_1) * block_size,
)
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
assert num_new_blocks_to_allocate == 10 and num_evictable_blocks_to_allocate == 0

num_new_blocks_to_allocate, num_evictable_blocks_to_allocate = (
manager.get_num_blocks_to_allocate(
"2",
20 * block_size,
cached_blocks_2,
total_computed_tokens=len(cached_blocks_2) * block_size,
)
)
assert num_new_blocks_to_allocate == 10 and num_evictable_blocks_to_allocate == 0

num_new_blocks_to_allocate, num_evictable_blocks_to_allocate = (
manager.get_num_blocks_to_allocate(
"3",
20 * block_size,
cached_blocks_3,
total_computed_tokens=len(cached_blocks_3) * block_size,
)
)
assert num_new_blocks_to_allocate == 15 and num_evictable_blocks_to_allocate == 1
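
The expected counts in the sliding-window asserts above can be reproduced with a short arithmetic sketch. This is one reading of the test, not code from vLLM, and it assumes that only cached blocks still inside the sliding window are counted as evictable, while everything past the cached prefix must be newly allocated.

    # Worked numbers for test_get_num_blocks_to_allocate (sliding window case).
    block_size = 2
    sliding_window = 2 * block_size                 # 4 tokens
    window_blocks = sliding_window // block_size    # 2 blocks

    num_tokens = 20 * block_size                    # the request needs 40 token slots
    total_blocks_needed = num_tokens // block_size  # 20 blocks
    num_cached_blocks = 10                          # len(cached_blocks_1)

    # Blocks beyond the cached prefix must be newly allocated.
    num_new = total_blocks_needed - num_cached_blocks  # 10
    # Of the cached blocks, only those still inside the sliding window have to
    # be held (and hence counted against free capacity); older ones become null.
    num_evictable = window_blocks                       # 2

    assert (num_new, num_evictable) == (10, 2)
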
17 changes: 7 additions & 10 deletions vllm/config/vllm.py
@@ -66,7 +66,7 @@ class OptimizationLevel(IntEnum):
"""O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately"""
O1 = 1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs"""
O2 = 2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
@@ -885,19 +885,16 @@ def has_blocked_weights():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# NOTE(Kuntai): turn HMA off for connector for now.
# NOTE(Yifan): warn when both a kv connector and the hybrid kv cache
# manager are enabled, but do not disable the hybrid kv cache manager here.
# TODO(Kuntai): have a more elegant solution to check and
# turn off HMA for connectors that do not support HMA.
logger.warning(
"Turning off hybrid kv cache manager because "
"`--kv-transfer-config` is set. This will reduce the "
"performance of vLLM on LLMs with sliding window attention "
"or Mamba attention. If you are a developer of kv connector"
", please consider supporting hybrid kv cache manager for "
"your connector by making sure your connector is a subclass"
" of `SupportsHMA` defined in kv_connector/v1/base.py."
"Warning: both kv connector and hybrid kv cache manager are "
"enabled. However, not all kv connectors support HMA. Please "
"check if the kv connector you are using supports HMA, or "
"disable HMA by setting `--disable-hybrid-kv-cache-manager`."
)
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
9 changes: 6 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -56,9 +56,12 @@ def create_connector(
# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
if hma_enabled and not supports_hma(connector_cls):
raise ValueError(
f"Connector {connector_cls.__name__} does not support HMA but "
f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
## REMOVE BEFORE MERGE (YIFAN): Revert this warning back to raising
# a ValueError.
logger.warning(
"Connector %s does not support HMA but HMA is enabled. Please set "
"--disable-hybrid-kv-cache-manager to disable HMA.",
connector_cls.__name__,
)
Comment on lines +59 to 65
critical

This change from raising a ValueError to a logger.warning is marked with a "REMOVE BEFORE MERGE" comment. Using a connector that does not support Hybrid Memory Allocation (HMA) when HMA is enabled can lead to incorrect behavior or hard-to-debug runtime errors. It is much safer to fail fast with an exception. This change should be reverted to raise ValueError before merging to prevent potential issues in production.

            raise ValueError(
                f"Connector {connector_cls.__name__} does not support HMA but "
                f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`.
            )


logger.info(
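
For context on the supports_hma check in create_connector above: the snippet below is a minimal, self-contained sketch of the opt-in pattern, using stand-in classes rather than vLLM's real base classes, and it assumes that supports_hma reduces to a subclass test on the SupportsHMA marker mixin (as the LMCacheConnectorV1 change later in this diff suggests).

    # Stand-in sketch only: these classes shadow the real vLLM names purely for
    # illustration, and this supports_hma implementation is an assumption.
    class SupportsHMA:
        """Marker mixin: the connector can work with the hybrid KV cache manager."""


    class KVConnectorBase_V1:
        """Stand-in for the real connector base class."""


    class MyHMAConnector(KVConnectorBase_V1, SupportsHMA):
        """A hypothetical connector that opts into HMA support."""


    def supports_hma(connector_cls: type) -> bool:
        # Assumed to be equivalent to a subclass check on the marker mixin.
        return issubclass(connector_cls, SupportsHMA)


    assert supports_hma(MyHMAConnector)
    assert not supports_hma(KVConnectorBase_V1)
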
@@ -13,6 +13,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
@@ -26,18 +27,21 @@
logger = init_logger(__name__)


class LMCacheConnectorV1(KVConnectorBase_V1):
class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
## REMOVE BEFORE MERGE (YIFAN): this is a temporary workaround to work with
# LMCache. Remove this once LMCache-side support for the new interfaces lands.
vllm_config.kv_cache_config = kv_cache_config # type: ignore[attr-defined]
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
assert vllm_config.kv_transfer_config is not None
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
assert vllm_config.kv_transfer_config is not None # type: ignore[attr-defined]
use_native = vllm_config.kv_transfer_config.get_from_extra_config( # type: ignore[attr-defined]
"use_native", False
)
if use_native:
@@ -213,4 +217,18 @@ def request_finished(
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
# NOTE: LMCache overloads request_finished so `block_ids` here can be
# either list[int] or tuple[list[int], ...].
return self._lmcache_engine.request_finished(request, block_ids)

## REMOVE BEFORE MERGE (YIFAN): this is a temporary workaround to work with
# LMCache. Remove this once LMCache-side support for the new interfaces lands.
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
# NOTE: LMCache overloads request_finished so `block_ids` here can be
# either list[int] or tuple[list[int], ...]. This could be changed in
# the future to separate these two methods.
return self._lmcache_engine.request_finished(request, block_ids)
Comment on lines +224 to 234
critical

The request_finished_all_groups method is marked as a temporary workaround with a "REMOVE BEFORE MERGE" comment. It appears to be a shim for a new interface required by the hybrid allocator. This temporary implementation should be replaced with a proper solution, and the dependency on this fix in LMCache should be resolved before this pull request is merged.
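
The note in the diff about request_finished accepting either block-id layout can be made concrete with a small, hypothetical helper. Nothing below is LMCache or vLLM code; it only shows how the flat list[int] form and the per-group tuple[list[int], ...] form relate.

    # Hypothetical helper: normalize the two block-id layouts described above
    # into the per-group form.
    def normalize_block_ids(
        block_ids: list[int] | tuple[list[int], ...],
    ) -> tuple[list[int], ...]:
        """Return one list of block ids per kv cache group."""
        if isinstance(block_ids, tuple):
            return block_ids      # already grouped: one list per kv cache group
        return (block_ids,)       # single-group cache: wrap the flat list


    # Both layouts map to the grouped form.
    assert normalize_block_ids([1, 2, 3]) == ([1, 2, 3],)
    assert normalize_block_ids(([1, 2], [7])) == ([1, 2], [7])
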

13 changes: 13 additions & 0 deletions vllm/v1/core/block_pool.py
@@ -253,7 +253,18 @@ def cache_full_blocks(
new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None
)

# Some blocks may be null blocks when sparse attention or sliding window
# attention is enabled. For now, only sliding window attention produces
# them, and null blocks always appear at the beginning.
first_non_null_blk_idx = 0
for i, blk in enumerate(new_full_blocks):
if not blk.is_null:
first_non_null_blk_idx = i
break

for i, blk in enumerate(new_full_blocks[first_non_null_blk_idx:]):
assert not blk.is_null
assert blk.block_hash is None
block_hash = new_block_hashes[i]
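
The loop above finds the first non-null block so that only real blocks get hashed and cached. Below is a self-contained sketch of the same step, with a stand-in Block class rather than vLLM's KVCacheBlock; note that this variant returns len(blocks) when every block is null, whereas the loop in the diff leaves the index at 0 in that case.

    # Illustrative stand-in, not vLLM code: skip leading null blocks before caching.
    from dataclasses import dataclass


    @dataclass
    class Block:
        block_id: int
        is_null: bool = False


    def first_non_null_index(blocks: list[Block]) -> int:
        """Index of the first non-null block, or len(blocks) if all are null."""
        return next(
            (i for i, blk in enumerate(blocks) if not blk.is_null), len(blocks)
        )


    blocks = [Block(0, is_null=True), Block(0, is_null=True), Block(3), Block(4)]
    start = first_non_null_index(blocks)
    assert start == 2
    assert all(not blk.is_null for blk in blocks[start:])
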

@@ -280,6 +291,8 @@
BlockStored(
block_hashes=new_hashes,
parent_block_hash=parent_block_hash,
## TODO(Yifan): here token_ids may be over-estimated for
## sliding window layers
Comment on lines +294 to +295
high

The TODO comment indicates that token_ids might be over-estimated for sliding window layers. This could lead to incorrect data in BlockStored events, which could be problematic for external systems consuming these events. If external systems rely on exact token IDs for correctness, this over-estimation could be a significant issue. This should be addressed to ensure data integrity for event consumers.

token_ids=request.all_token_ids[
num_cached_blocks * block_size : num_full_blocks * block_size
],
59 changes: 46 additions & 13 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -73,7 +73,8 @@ def get_num_blocks_to_allocate(
num_tokens: int,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int,
) -> int:
total_computed_tokens: int,
) -> tuple[list[int], list[int]]:
"""
Get the number of blocks needed to be allocated for the request.

@@ -85,26 +86,45 @@
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
total_computed_tokens: The total number of computed tokens, including
both local and external tokens.

Returns:
The number of blocks.
Two lists, one entry per kv cache group: the number of new blocks to
allocate, and the number of evictable cached blocks that count against
free capacity but are not newly allocated.
"""
num_blocks_to_allocate = 0
num_new_blocks_to_allocate = []
num_evictable_blocks_to_allocate = []
for i, manager in enumerate(self.single_type_managers):
if isinstance(manager, CrossAttentionManager):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, []
(
num_new_blocks_to_allocate_single_group,
num_evictable_blocks_to_allocate_single_group,
) = manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [], 0
)
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i]
(
num_new_blocks_to_allocate_single_group,
num_evictable_blocks_to_allocate_single_group,
) = manager.get_num_blocks_to_allocate(
request_id,
num_tokens,
new_computed_blocks[i],
total_computed_tokens,
)
return num_blocks_to_allocate
num_new_blocks_to_allocate.append(num_new_blocks_to_allocate_single_group)
num_evictable_blocks_to_allocate.append(
num_evictable_blocks_to_allocate_single_group
)
return num_new_blocks_to_allocate, num_evictable_blocks_to_allocate

def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
self,
request_id: str,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
total_computed_tokens: int,
) -> None:
"""
Add the new computed blocks to the request.
@@ -113,19 +133,31 @@ def save_new_computed_blocks(
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
total_computed_tokens: The total number of computed tokens, including
both local and external tokens.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id, new_computed_blocks[i])
manager.save_new_computed_blocks(
request_id, new_computed_blocks[i], total_computed_tokens
)

def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
self,
request_id: str,
num_blocks_to_allocate_per_group: list[int],
num_tokens: int,
num_encoder_tokens: int = 0,
) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
token slots. If `num_blocks_to_allocate_per_group[i]` is smaller than
the number of blocks needed (in the case of sliding window attention),
the leading blocks will be padded with null blocks.

Args:
request_id: The request ID.
num_blocks_to_allocate_per_group: The number of blocks to allocate
for each kv cache group.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_encoder_tokens: The number of encoder tokens for allocating
@@ -137,11 +169,12 @@
return tuple(
manager.allocate_new_blocks(
request_id,
num_blocks_to_allocate_per_group[i],
num_encoder_tokens
if isinstance(manager, CrossAttentionManager)
else num_tokens,
)
for manager in self.single_type_managers
for i, manager in enumerate(self.single_type_managers)
)

def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
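
Taken together, the new signatures above change the caller's flow: capacity is checked against both per-group lists, but only the new-block counts are handed to allocate_new_blocks. The function below is a hedged sketch of that flow, not the actual KVCacheManager code; the coordinator/block_pool parameters and the get_num_free_blocks call are assumptions.

    # Hedged sketch of a caller of the coordinator's new interface; names and
    # the exact free-block check are assumptions, not vLLM's implementation.
    def try_allocate(
        coordinator,
        block_pool,
        request_id: str,
        num_tokens: int,
        new_computed_blocks,
        num_encoder_tokens: int,
        total_computed_tokens: int,
    ):
        num_new, num_evictable = coordinator.get_num_blocks_to_allocate(
            request_id=request_id,
            num_tokens=num_tokens,
            new_computed_blocks=new_computed_blocks,
            num_encoder_tokens=num_encoder_tokens,
            total_computed_tokens=total_computed_tokens,
        )
        # The capacity check counts evictable cached blocks as well, since
        # touching them removes them from the free queue.
        if sum(num_new) + sum(num_evictable) > block_pool.get_num_free_blocks():
            return None  # not enough free blocks; try again later
        coordinator.save_new_computed_blocks(
            request_id, new_computed_blocks, total_computed_tokens
        )
        # Only the truly new blocks are allocated; evictable cached blocks are
        # reused instead of being allocated a second time.
        return coordinator.allocate_new_blocks(
            request_id, num_new, num_tokens, num_encoder_tokens
        )
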