[Core][Hybrid allocator + connector] Support hybrid allocator + kv cache connector #30166
The first file updates the LMCache connector, `LMCacheConnectorV1`, to register hybrid-allocator support:

```diff
@@ -13,6 +13,7 @@
     KVConnectorBase_V1,
     KVConnectorMetadata,
     KVConnectorRole,
+    SupportsHMA,
 )
 from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
```
```diff
@@ -26,18 +27,21 @@
 logger = init_logger(__name__)


-class LMCacheConnectorV1(KVConnectorBase_V1):
+class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA):
     def __init__(
         self,
         vllm_config: "VllmConfig",
         role: KVConnectorRole,
         kv_cache_config: "KVCacheConfig",
     ):
+        ## REMOVE BEFORE MERGE (YIFAN): this is a temporary workaround to work with
+        # LMCache. Remove this once LMCache-side support for the new interfaces lands.
+        vllm_config.kv_cache_config = kv_cache_config  # type: ignore[attr-defined]
         super().__init__(
             vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
         )
-        assert vllm_config.kv_transfer_config is not None
-        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
+        assert vllm_config.kv_transfer_config is not None  # type: ignore[attr-defined]
+        use_native = vllm_config.kv_transfer_config.get_from_extra_config(  # type: ignore[attr-defined]
             "use_native", False
         )
         if use_native:
```
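For context on the new base class: `SupportsHMA` reads as a capability marker that lets the engine detect at runtime whether a connector can handle the hybrid allocator's per-group block IDs. A minimal sketch of that pattern follows; only the `SupportsHMA` name comes from the diff, and the stub body and helper are illustrative:

```python
# Sketch of a marker-mixin capability check (stubs and helper are assumptions).


class SupportsHMA:
    """Stub of the marker imported above: connector handles per-group block IDs."""


class LegacyConnector:
    """Illustrative: a connector that predates the hybrid allocator."""


class HybridConnector(SupportsHMA):
    """Illustrative: a connector that opts in to hybrid-allocator support."""


def connector_supports_hma(connector: object) -> bool:
    # Hypothetical helper: the marker carries no methods, so an
    # isinstance() check is all that is needed.
    return isinstance(connector, SupportsHMA)


assert not connector_supports_hma(LegacyConnector())
assert connector_supports_hma(HybridConnector())
```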
```diff
@@ -213,4 +217,18 @@ def request_finished(
         Optional KVTransferParams to be included in the request outputs
         returned by the engine.
         """
+        # NOTE: LMCache overloads request_finished so `block_ids` here can be
+        # either list[int] or tuple[list[int], ...].
         return self._lmcache_engine.request_finished(request, block_ids)
+
+    ## REMOVE BEFORE MERGE (YIFAN): this is a temporary workaround to work with
+    # LMCache. Remove this once LMCache-side support for the new interfaces lands.
+    def request_finished_all_groups(
+        self,
+        request: "Request",
+        block_ids: tuple[list[int], ...],
+    ) -> tuple[bool, dict[str, Any] | None]:
+        # NOTE: LMCache overloads request_finished so `block_ids` here can be
+        # either list[int] or tuple[list[int], ...]. This could be changed in
+        # the future to separate these two methods.
+        return self._lmcache_engine.request_finished(request, block_ids)
```
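The `tuple[list[int], ...]` type in the new method reflects the hybrid allocator's data model: each KV cache group (for example, full attention vs. sliding window) keeps its own list of block IDs. A small illustration of the two shapes the NOTE says `request_finished` must tolerate; the `flatten_block_ids` helper is hypothetical, not part of the PR:

```python
# Illustrative only: the two block-ID shapes described in the NOTE above.
BlockIds = list[int] | tuple[list[int], ...]


def flatten_block_ids(block_ids: BlockIds) -> list[int]:
    """Hypothetical helper: collapse per-group block IDs into one flat list."""
    if isinstance(block_ids, tuple):
        # Hybrid-allocator shape: one list of block IDs per KV cache group.
        return [bid for group in block_ids for bid in group]
    # Single-group shape: already a flat list.
    return block_ids


assert flatten_block_ids([1, 2, 3]) == [1, 2, 3]
assert flatten_block_ids(([1, 2], [7, 8])) == [1, 2, 7, 8]
```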
Contributor comment on lines +224 to 234: The …
The second file changes the block-caching path, `cache_full_blocks`, to tolerate null blocks from sliding-window layers:
```diff
@@ -253,7 +253,18 @@ def cache_full_blocks(
         new_hashes: list[ExternalBlockHash] | None = (
             [] if self.enable_kv_cache_events else None
         )
-        for i, blk in enumerate(new_full_blocks):
+
+        # Some blocks may be null blocks when enabling sparse attention or sliding
+        # window attention. For now, we only have sliding window attention, and
+        # null blocks must be at the beginning.
+        first_non_null_blk_idx = 0
+        for i, blk in enumerate(new_full_blocks):
+            if not blk.is_null:
+                first_non_null_blk_idx = i
+                break
+
+        for i, blk in enumerate(new_full_blocks[first_non_null_blk_idx:]):
             assert not blk.is_null
             assert blk.block_hash is None
             block_hash = new_block_hashes[i]
```
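The scan is needed because, with sliding-window layers, a request's leading blocks can be null (their tokens have slid out of the window), so only the tail of `new_full_blocks` carries cacheable data. A toy reproduction of the scan, using a stand-in `Block` dataclass rather than vLLM's real block type:

```python
from dataclasses import dataclass


@dataclass
class Block:
    """Stand-in for vLLM's block type, reduced to the fields the scan reads."""

    block_id: int
    is_null: bool = False


# Two leading null blocks, as a sliding-window layer might produce.
new_full_blocks = [Block(0, is_null=True), Block(1, is_null=True), Block(2), Block(3)]

first_non_null_blk_idx = 0
for i, blk in enumerate(new_full_blocks):
    if not blk.is_null:
        first_non_null_blk_idx = i
        break

assert first_non_null_blk_idx == 2
assert all(not b.is_null for b in new_full_blocks[first_non_null_blk_idx:])
```

Note the implicit assumption: if every block were null, the index would stay at 0 and the `assert not blk.is_null` in the loop that follows would fire.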
```diff
@@ -280,6 +291,8 @@ def cache_full_blocks(
                 BlockStored(
                     block_hashes=new_hashes,
                     parent_block_hash=parent_block_hash,
+                    ## TODO(Yifan): here token_ids may be over-estimated for
+                    ## sliding window layers
                     token_ids=request.all_token_ids[
                         num_cached_blocks * block_size : num_full_blocks * block_size
                     ],
```

Contributor comment on lines +294 to +295: The TODO comment indicates that …
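To make the TODO concrete: the slice is taken over the full range from `num_cached_blocks` to `num_full_blocks` even when leading blocks are null, so for a sliding-window group the `BlockStored` event can report more token IDs than are actually cached. A worked example with illustrative numbers:

```python
# Illustrative numbers, not taken from the PR.
block_size = 16
num_cached_blocks = 0
num_full_blocks = 4  # two of these are leading null blocks (sliding window)

all_token_ids = list(range(100))
token_ids = all_token_ids[num_cached_blocks * block_size : num_full_blocks * block_size]

# 64 token IDs ride along with the BlockStored event, even though the two
# null blocks hold no KV data: only the last 32 tokens are actually cached.
assert len(token_ids) == 64
```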
Review comment: This change from raising a `ValueError` to a `logger.warning` is marked with a "REMOVE BEFORE MERGE" comment. Using a connector that does not support Hybrid Memory Allocation (HMA) when HMA is enabled can lead to incorrect behavior or hard-to-debug runtime errors. It is much safer to fail fast with an exception. This change should be reverted to `raise ValueError` before merging to prevent potential issues in production.
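A sketch of the fail-fast behavior the reviewer asks to restore; the function name and placement are hypothetical, and `SupportsHMA` is stubbed again for self-containment:

```python
class SupportsHMA:
    """Stub of the marker mixin from the diff above."""


def validate_connector_for_hma(connector: object, hma_enabled: bool) -> None:
    """Hypothetical check: raise instead of logging a warning."""
    if hma_enabled and not isinstance(connector, SupportsHMA):
        raise ValueError(
            f"KV connector {type(connector).__name__} does not support the "
            "hybrid memory allocator (HMA); disable HMA or use an "
            "HMA-capable connector."
        )
```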