diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index d7c55c6e7df..71c67c46681 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -501,13 +501,7 @@ def select_torchair_padded_batch_size(self, batch_size: int):
     def update_torchair_graph_batch_sizes(self):
        # return graph_batch_sizes according to the max number of tokens
        # first pad according to the number of requests
-        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-            # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
-            self.torchair_graph_batch_sizes = [self.max_num_reqs]
-            logger.warning(
-                f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
-            )
-        elif len(self.torchair_graph_batch_sizes) == 0:
+        if len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
@@ -537,10 +531,11 @@ def update_torchair_graph_batch_sizes(self):
 
     def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
+        lcm_size = math.lcm(tp_size, self.decode_token_per_req)
         new_graph_batch_sizes = []
         for graph_batch_size in self.torchair_graph_batch_sizes:
-            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
-                graph_batch_size, tp_size)
+            cur_graph_batch_size = (graph_batch_size + lcm_size -
+                                    1) // lcm_size * lcm_size
             if cur_graph_batch_size not in new_graph_batch_sizes and \
                 cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                 new_graph_batch_sizes.append(cur_graph_batch_size)
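
A minimal standalone sketch of the rounding introduced in the second hunk: each candidate graph batch size is padded up to the nearest multiple of lcm(tp_size, decode_token_per_req), with duplicates and oversized results dropped. The function name align_graph_batch_sizes and the free-standing signature are illustrative only; the real code is a method of the model runner.

import math

def align_graph_batch_sizes(graph_batch_sizes, tp_size, decode_token_per_req,
                            max_num_batched_tokens):
    """Round each candidate batch size up to a multiple of
    lcm(tp_size, decode_token_per_req), skipping duplicates and any
    result above max_num_batched_tokens (mirrors the diff above)."""
    lcm_size = math.lcm(tp_size, decode_token_per_req)
    aligned = []
    for size in graph_batch_sizes:
        # Ceiling division then scale back: smallest multiple of lcm_size >= size.
        padded = (size + lcm_size - 1) // lcm_size * lcm_size
        if padded not in aligned and padded <= max_num_batched_tokens:
            aligned.append(padded)
    return aligned

# Example (hypothetical values): tp_size=4, decode_token_per_req=2 -> lcm_size=4,
# so [1, 6, 16] is padded to [4, 8, 16].
print(align_graph_batch_sizes([1, 6, 16], 4, 2, 64))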