1+ #
2+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+ #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ #
8+ # http://www.apache.org/licenses/LICENSE-2.0
9+ #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+ # This file is a part of the vllm-ascend project.
16+ #
17+
18+ from dataclasses import dataclass , fields
19+ from typing import Type , Union
20+
21+ from vllm .config import SchedulerConfig
22+
# Largest 32-bit signed integer (2**31 - 1). Used as an "effectively
# unbounded" sentinel for long_prefill_token_threshold.
MAX_INT = 2147483647
24+
25+
@dataclass
class AscendSchedulerConfig(SchedulerConfig):
    """Scheduler configuration for the Ascend backend.

    Extends vLLM's ``SchedulerConfig`` with Ascend-specific defaults and
    knobs, and pins ``scheduler_cls`` to ``AscendScheduler``.
    """

    # Ascend scheduler runs without chunked prefill by default.
    enable_chunked_prefill: bool = False
    # Max number of concurrent "long" partial prefills; 1 means not enabled.
    max_long_partial_prefills: int = 1
    # Token count above which a prefill request is considered "long".
    long_prefill_token_threshold: int = MAX_INT
    # Only first-come-first-served is supported (enforced in __post_init__).
    policy: str = "fcfs"
    scheduler_cls: Union[str, Type[object]] = (
        "vllm_ascend.core.scheduler.AscendScheduler")
    # Whether prefill/decode (PD) transfer is enabled.
    enable_pd_transfer: bool = False
    # Max number of sequences scheduled per decode step; 0 means no limit.
    decode_max_num_seqs: int = 0

    @classmethod
    def initialize_from_config(
        cls,
        vllm_scheduler_config: SchedulerConfig,
        ascend_scheduler_config,
    ):
        """Build an ``AscendSchedulerConfig`` from a vLLM ``SchedulerConfig``.

        Starts from the init-fields of ``vllm_scheduler_config``, resets the
        Ascend-specific defaults, then overrides every field that is present
        on ``ascend_scheduler_config``.
        """
        scheduler_config = {
            field.name: getattr(vllm_scheduler_config, field.name)
            for field in fields(vllm_scheduler_config) if field.init
        }
        # Override default values into original SchedulerConfig.
        scheduler_config["enable_chunked_prefill"] = False
        # None means "not set"; __post_init__ fills in the real defaults.
        scheduler_config["max_long_partial_prefills"] = None
        scheduler_config["long_prefill_token_threshold"] = None
        scheduler_config["policy"] = "fcfs"
        scheduler_config["scheduler_cls"] = (
            "vllm_ascend.core.scheduler.AscendScheduler")
        scheduler_config["enable_pd_transfer"] = False
        scheduler_config["decode_max_num_seqs"] = 0
        # Override params in original SchedulerConfig with params in
        # ascend_scheduler_config. Only keys are needed; values are
        # reassigned in place (safe: no keys are added or removed).
        for k in scheduler_config:
            if hasattr(ascend_scheduler_config, k):
                scheduler_config[k] = getattr(ascend_scheduler_config, k)
        return cls(**scheduler_config)

    def __post_init__(self, *args) -> None:
        """Validate the configuration and derive dependent fields.

        Raises:
            ValueError: if ``max_num_batched_tokens`` < ``max_model_len``
                while chunked prefill is disabled, or if a partial-prefill
                knob is negative.
            NotImplementedError: for a non-fcfs policy or a positive
                ``scheduler_delay_factor``.
        """
        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
        self.encoder_cache_size = self.max_num_batched_tokens
        self.chunked_prefill_enabled = self.enable_chunked_prefill
        if (self.max_num_batched_tokens < self.max_model_len
                and not self.chunked_prefill_enabled):
            raise ValueError(
                "Ascend scheduler is enabled without chunked prefill feature. "
                f"Argument max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        # Concurrent partial prefills. Default is 1, meaning not enabled.
        if self.max_long_partial_prefills is None:
            self.max_long_partial_prefills = 1
            self.long_prefill_token_threshold = MAX_INT

        if self.long_prefill_token_threshold is None or \
                self.long_prefill_token_threshold <= 0:
            if self.max_model_len is None:
                self.long_prefill_token_threshold = MAX_INT
            else:
                # Default threshold: 4% of the model context length,
                # but at least one token.
                self.long_prefill_token_threshold = \
                    max(1, int(self.max_model_len * 0.04))

        if self.max_long_partial_prefills < 0:
            raise ValueError(
                f"max_long_partial_prefills must be non-negative, but got "
                f"{self.max_long_partial_prefills}")
        if self.long_prefill_token_threshold < 0:
            raise ValueError(
                f"long_prefill_token_threshold must be non-negative, but got "
                f"{self.long_prefill_token_threshold}")

        if self.policy != "fcfs":
            raise NotImplementedError(
                f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
            )
        if getattr(self, "scheduler_delay_factor", 0) > 0:
            raise NotImplementedError(
                "currently AscendScheduler doesn't support scheduler_delay_factor."
            )