From 15f3d6598d32b9b5c1931ed1fe425de1f837ed7d Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Fri, 5 Dec 2025 22:26:26 +0800
Subject: [PATCH] [Feature] Support multi graphs for torchair

Signed-off-by: Jade Zheng
---
 vllm_ascend/torchair/torchair_model_runner.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index d7c55c6e7df..71c67c46681 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -501,13 +501,7 @@ def select_torchair_padded_batch_size(self, batch_size: int):
     def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
-        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-            # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
-            self.torchair_graph_batch_sizes = [self.max_num_reqs]
-            logger.warning(
-                f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
-            )
-        elif len(self.torchair_graph_batch_sizes) == 0:
+        if len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
@@ -537,10 +531,11 @@ def update_torchair_graph_batch_sizes(self):
 
     def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
+        lcm_size = math.lcm(tp_size, self.decode_token_per_req)
         new_graph_batch_sizes = []
         for graph_batch_size in self.torchair_graph_batch_sizes:
-            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
-                graph_batch_size, tp_size)
+            cur_graph_batch_size = (graph_batch_size + lcm_size -
+                                    1) // lcm_size * lcm_size
             if cur_graph_batch_size not in new_graph_batch_sizes and \
                 cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                 new_graph_batch_sizes.append(cur_graph_batch_size)
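
Note on the new alignment arithmetic: the patch drops the call to
calculate_new_torchair_graph_batch_size and instead rounds each candidate
graph batch size up to the nearest multiple of
lcm(tp_size, decode_token_per_req), so every captured graph size divides
evenly by both the tensor-parallel size and the number of decode tokens
per request. Below is a minimal standalone sketch of that rounding; the
helper name align_graph_batch_sizes and the example numbers are
illustrative only and not part of the patch:

    import math

    def align_graph_batch_sizes(batch_sizes, tp_size, decode_token_per_req,
                                max_num_batched_tokens):
        # Smallest stride divisible by both constraints.
        lcm_size = math.lcm(tp_size, decode_token_per_req)
        aligned = []
        for size in batch_sizes:
            # Ceiling-divide then rescale: the smallest multiple of
            # lcm_size that is >= size, same as the inlined expression.
            cur = (size + lcm_size - 1) // lcm_size * lcm_size
            # Drop duplicates and sizes over the token budget, as the hunk does.
            if cur not in aligned and cur <= max_num_batched_tokens:
                aligned.append(cur)
        return aligned

    # With tp_size=4 and decode_token_per_req=2, lcm_size is 4,
    # so [1, 24] pads to [4, 24].
    print(align_graph_batch_sizes([1, 24], 4, 2, 256))

Because the stride is the least common multiple rather than tp_size alone,
the padded sizes stay valid for MTP-style decoding where each request
contributes more than one token per step.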