Skip to content

Commit c1bfdfc

Browse files
authored
Utility and example for custom op expansion (#2701)
A utility and an example showing how onnxscript functions can be used to define function expansions and be used with the inliner to replace calls to the custom function with an expanded subgraph. This is useful to perform certain classes of graph surgery easily.
1 parent 97513c7 commit c1bfdfc

File tree

3 files changed

+97
-0
lines changed

3 files changed

+97
-0
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ include_patterns = [
3939
exclude_patterns = [
4040
'tests/**', # Skip linting test files for speed
4141
# FIXME: Fix typing annotations in these files
42+
'examples/custom_op_expansion.py',
4243
'onnxscript/converter_test.py',
4344
'onnxscript/converter.py',
4445
'onnxscript/evaluator_test.py',

examples/custom_op_expansion.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
# ruff: noqa
4+
5+
"""A utility and an example showing how onnxscript functions can be used to define function expansions
6+
and be used with the inliner to replace calls to the custom function with an expanded subgraph.
7+
This is useful to perform certain classes of graph surgery easily.
8+
"""
9+
10+
import onnx
11+
12+
import onnxscript
13+
import onnxscript.utils.replace as replace
14+
15+
script = onnxscript.script
16+
FLOAT = onnxscript.FLOAT
17+
op = onnxscript.values.opset22
18+
local = onnxscript.values.Opset("local", 1)
19+
20+
21+
# Example Model: Actual models can come from ModelBuilder or Exporter or any other source.
22+
# Models can contain calls to custom operations (from a custom domain like 'local' here or
23+
# even "com.microsoft" etc.)
24+
@script()
25+
def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
26+
DoubleX = op.Add(X, X)
27+
YSquare = op.Mul(Y, Y)
28+
# Example call to a custom operation
29+
Temp1 = local.CustomOp1(DoubleX, YSquare)
30+
# Another call to a custom operation with an attribute
31+
Temp2 = local.CustomOp2(Temp1, alp=0.9)
32+
return Temp2
33+
34+
35+
# Define expansions for custom operations as onnxscript functions
36+
@script(opset=local)
37+
def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
38+
Temp1 = op.Sub(X, Y)
39+
return op.Div(Temp1, X)
40+
41+
42+
@script(opset=local)
43+
def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]:
44+
Temp2 = op.Elu(X, alpha=alp)
45+
return op.Mul(Temp2, Temp2)
46+
47+
48+
# Now, we can replace the custom operations in the model with their expansions:
49+
50+
functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()]
51+
52+
model = model_script.to_model_proto()
53+
54+
print("Original Model with custom operations:")
55+
print(onnx.printer.to_text(model))
56+
57+
58+
updated_model = replace.replace_functions(model, functions)
59+
60+
print("\nUpdated Model after replacing custom operations with their expansions:")
61+
print(onnx.printer.to_text(updated_model))

onnxscript/utils/replace.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""A utility function to replace custom operations in a model with their expansions"""
4+
5+
from typing import Sequence
6+
7+
import onnx
8+
import onnx_ir as ir
9+
import onnx_ir.passes.common as common_passes
10+
11+
12+
def replace_functions(
13+
model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto]
14+
) -> onnx.ModelProto:
15+
"""A utility function to replace custom operations in a model with their expansions:
16+
Args:
17+
model: An ONNX ModelProto possibly containing calls to custom operations.
18+
functions: A sequence of FunctionProto defining the expansions for the custom operations.
19+
20+
Returns:
21+
An updated ModelProto with custom operations replaced by their expansions.
22+
"""
23+
irmodel = ir.from_proto(model)
24+
irfunctions = [ir.from_proto(func) for func in functions]
25+
model_functions = irmodel.functions
26+
if len(model_functions) != 0:
27+
# Since we use inlining, check that there are no model-local functions.
28+
raise ValueError("Input model cannot have model-local functions.")
29+
for func in irfunctions:
30+
model_functions[func.identifier()] = func
31+
32+
# TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values.
33+
common_passes.InlinePass()(irmodel)
34+
common_passes.RemoveUnusedOpsetsPass()(irmodel)
35+
return ir.to_proto(irmodel)

0 commit comments

Comments
 (0)