
Commit 18425fe

Fixed the comments. Added more tests
1 parent a67d4a6 commit 18425fe

5 files changed: +268 -21 lines changed


examples/dynamo/low_cpu_memory_compilation.py

Lines changed: 23 additions & 4 deletions
@@ -86,25 +86,44 @@ def forward(self, x):
 
 """
 You should be able to see two back-to-back TensorRT engines in the graph
+
 Graph Structure:
 
 Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
 ...
-TRT Engine #1 - Submodule name: _run_on_acc_0
+TRT Engine #1 - Submodule name: _run_on_acc_0_resource_split_0
 Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
 Number of Operators in Engine: 9
 Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32]
 ...
-TRT Engine #2 - Submodule name: _run_on_acc_1
+TRT Engine #2 - Submodule name: _run_on_acc_0_resource_split_1
 Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32]
 Number of Operators in Engine: 3
 Engine Outputs: List[Tensor: (1, 10)@float32]
 ...
 Outputs: List[Tensor: (1, 10)@float32]
 
+------------------------- Aggregate Stats -------------------------
+
+Average Number of Operators per TRT Engine: 6.0
+Most Operators in a TRT Engine: 9
 
+********** Recommendations **********
+
+- For minimal graph segmentation, select min_block_size=9 which would generate 1 TRT engine(s)
+- For moderate graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s)
+- The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 2 TRT engine(s)
 GraphModule(
-  (_run_on_acc_0): TorchTensorRTModule()
-  (_run_on_acc_1): TorchTensorRTModule()
+  (_run_on_acc_0_resource_split_0): TorchTensorRTModule()
+  (_run_on_acc_0_resource_split_1): TorchTensorRTModule()
+)
+
+
+
+def forward(self, x):
+    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+    _run_on_acc_0_resource_split_0 = self._run_on_acc_0_resource_split_0(x); x = None
+    _run_on_acc_0_resource_split_1 = self._run_on_acc_0_resource_split_1(_run_on_acc_0_resource_split_0); _run_on_acc_0_resource_split_0 = None
+    return pytree.tree_unflatten((_run_on_acc_0_resource_split_1,), self._out_spec)
 )
 """

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 4 deletions
@@ -622,10 +622,6 @@ def compile(
             "'arg_inputs' and 'inputs' should not be used at the same time."
         )
 
-    assert (
-        cpu_memory_budget >= 2 * 1024 * 1024 * 1024
-    ), "CPU memory budget must be greater than 10GB"
-
     arg_inputs = inputs or arg_inputs
 
     if kwarg_inputs is None:

py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from functools import lru_cache
 from typing import Any, Callable, Dict, List, Set, Tuple
 
@@ -135,7 +136,7 @@ def get_node_in_fusion_pattern(
         Key: node that appears in the fusion pattern
         Value: the list of nodes that should be fused together
     """
-    fusion_nodes = {}
+    fusion_nodes = defaultdict(set)
     for compiled_pattern_graph in get_compiled_atomic_subgraphs():
         subgraph_matcher = SubgraphMatcher(compiled_pattern_graph.graph)
         match_result = subgraph_matcher.match(graph)
@@ -149,7 +150,7 @@ def get_node_in_fusion_pattern(
                 and node not in match.placeholder_nodes
             }
             for node in fusion_group:
-                fusion_nodes[node] = fusion_group
+                fusion_nodes[node].update(fusion_group)
 
     return fusion_nodes
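
A standalone sketch of why the switch to defaultdict(set) with update matters: when a node participates in two overlapping fusion matches, plain dict assignment kept only the last group seen, while sets merge the groups. Node names here are placeholder strings, not the repository's FX nodes.

from collections import defaultdict

fusion_nodes = defaultdict(set)
group_a = {"conv", "bn"}
group_b = {"bn", "relu"}  # overlaps with group_a on "bn"

# Record each node's fusion group, merging instead of overwriting on overlap.
for node in group_a:
    fusion_nodes[node].update(group_a)
for node in group_b:
    fusion_nodes[node].update(group_b)

print(fusion_nodes["bn"])  # {'conv', 'bn', 'relu'} -- merged, not overwritten (order may vary)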

py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py

Lines changed: 2 additions & 9 deletions
@@ -46,7 +46,7 @@
 """
 
 import logging
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
 
 import psutil
 import torch
@@ -92,7 +92,7 @@ def __init__(
 
         self._node_submodule_map: Dict[str, str] = {}
         self._return_tuple = False
-        self.fusion_patterns: Dict[torch.fx.Node, List[torch.fx.Node]] = {}
+        self.fusion_patterns: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
 
     def partition_graph(self) -> torch.fx.GraphModule:
         """Build the final partitioned `GraphModule` honoring memory constraints.
@@ -214,7 +214,6 @@ def break_subgraphs(
         # We throw an error if the remaining memory is almost empty compared to the model size.
         # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation.
         sizes = self.size_of_subgraphs(subgraphs)
-        # subgraph_size_budget = 500*1024*1024
         if sum(sizes) > subgraph_size_budget * 40:
             raise ValueError(
                 "CPU memory budget or available memory is too small to compile the model. "
@@ -470,12 +469,6 @@ def validate_and_correct_subgraphs(
                 visited_nodes[subgraph.nodes[-1]] = i + 1
                 continue
 
-            elif not subgraph.is_acc:
-                # non-accelerated subgraphs should be put in the next subgraph
-                for node in subgraph.nodes:
-                    visited_nodes[subgraph.nodes[-1]] = i + 1
-                continue
-
             else:
                 to_remove_nodes = []
                 for j, node in enumerate(subgraph.nodes):
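
The context lines in the break_subgraphs hunk already show the safeguard; below is a standalone sketch of the same 40x check with simplified names. How subgraph_size_budget is derived elsewhere in the file (e.g. from psutil and the cpu_memory_budget) is not shown in this hunk and is assumed.

from typing import List


def check_model_fits(subgraph_sizes: List[int], subgraph_size_budget: int) -> None:
    # Abort when the total model size dwarfs the per-subgraph budget (40x),
    # since splitting subgraphs alone cannot keep peak CPU memory in reach.
    if sum(subgraph_sizes) > subgraph_size_budget * 40:
        raise ValueError(
            "CPU memory budget or available memory is too small to compile the model."
        )


# 768 MiB of subgraphs against a 1 GiB budget stays well under the 40x limit.
check_model_fits([512 * 1024**2, 256 * 1024**2], subgraph_size_budget=1 * 1024**3)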
