onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections import Counter
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import onnx
|
|
8
|
+
|
|
9
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
10
|
+
from onnxslim.core.pattern.registry import get_fusion_patterns
|
|
11
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("onnxslim")
|
|
14
|
+
|
|
15
|
+
from .dead_node_elimination import dead_node_elimination
|
|
16
|
+
from .subexpression_elimination import subexpression_elimination
|
|
17
|
+
from .weight_tying import tie_weights
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OptimizationSettings:
|
|
21
|
+
constant_folding = True
|
|
22
|
+
graph_fusion = True
|
|
23
|
+
dead_node_elimination = True
|
|
24
|
+
subexpression_elimination = True
|
|
25
|
+
weight_tying = True
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def keys(cls):
|
|
29
|
+
return [
|
|
30
|
+
"constant_folding",
|
|
31
|
+
"graph_fusion",
|
|
32
|
+
"dead_node_elimination",
|
|
33
|
+
"subexpression_elimination",
|
|
34
|
+
"weight_tying",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
@classmethod
|
|
38
|
+
def reset(cls, skip_optimizations: list[str] | None = None):
|
|
39
|
+
for key in cls.keys():
|
|
40
|
+
if skip_optimizations and key in skip_optimizations:
|
|
41
|
+
setattr(cls, key, False)
|
|
42
|
+
else:
|
|
43
|
+
setattr(cls, key, True)
|
|
44
|
+
|
|
45
|
+
@classmethod
|
|
46
|
+
def stats(cls):
|
|
47
|
+
return {key: getattr(cls, key) for key in cls.keys()}
|
|
48
|
+
|
|
49
|
+
@classmethod
|
|
50
|
+
def enabled(cls):
|
|
51
|
+
return any([getattr(cls, key) for key in cls.keys()])
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str | None = None) -> onnx.ModelProto:
    """Optimize and transform the given ONNX model using various fusion patterns and graph rewriting techniques.

    Runs each pass only when its flag on ``OptimizationSettings`` is enabled,
    then exports the rewritten graph back to an ``onnx.ModelProto``.
    """
    # Accept either a raw ONNX proto or an already-imported graphsurgeon graph.
    if isinstance(model, gs.Graph):
        graph = model
    else:
        graph = gs.import_onnx(model)

    if OptimizationSettings.graph_fusion:
        logger.debug("Start graph_fusion.")
        fusion_patterns = get_fusion_patterns(skip_fusion_patterns)
        graph_fusion(graph, fusion_patterns)
        logger.debug("Finish graph_fusion.")

    if OptimizationSettings.dead_node_elimination:
        logger.debug("Start dead_node_elimination.")
        dead_node_elimination(graph)
        # Erasing nodes leaves dangling tensors; prune and re-sort before the next pass.
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish dead_node_elimination.")

    if OptimizationSettings.subexpression_elimination:
        logger.debug("Start subexpression_elimination.")
        subexpression_elimination(graph)
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish subexpression_elimination.")

    if OptimizationSettings.weight_tying:
        logger.debug("Start weight_tying.")
        tie_weights(graph)
        logger.debug("Finish weight_tying.")

    return gs.export_onnx(graph)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@gs.Graph.register()
def replace_custom_layer(
    self,
    op: str,
    inputs,
    outputs: list[str],
    name: str,
    attrs: dict | None = None,
    domain: str = "ai.onnx.contrib",
):
    """Insert a replacement (fused) layer into the graph with the given op, I/O, name, attributes, and domain."""
    layer_kwargs = {
        "op": op,
        "inputs": inputs,
        "outputs": outputs,
        "name": name,
        "attrs": attrs,
        "domain": domain,
    }
    return self.layer(**layer_kwargs)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def graph_fusion(graph: Graph, fusion_patterns: dict, is_subgraph=False):
    """Apply fusion patterns to ``graph`` (and, recursively, to its subgraphs),
    replacing each matched subgraph with a single fused custom layer.

    Args:
        graph: The graph to rewrite in place.
        fusion_patterns: Mapping of layer type -> pattern matcher.
        is_subgraph: True on recursive calls for nested subgraphs; a subgraph's
            inputs are fed by the enclosing graph and must not be pruned.
    """
    # Fuse nested subgraphs (If/Loop bodies) before the enclosing graph.
    for subgraph in graph.subgraphs():
        graph_fusion(subgraph, fusion_patterns, is_subgraph=True)

    fusion_pairs = find_matches(graph, fusion_patterns)
    for match in fusion_pairs.values():
        graph.replace_custom_layer(**match)

    # Idiom fix: `True if not is_subgraph else False` simplifies to `not is_subgraph`.
    graph.cleanup(remove_unused_graph_inputs=not is_subgraph).toposort()
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def find_matches(graph: Graph, fusion_patterns: dict):
    """Find matching patterns in the graph based on provided fusion patterns.

    Walks the nodes in reverse topological order, trying every registered
    pattern matcher on each node, and collects the rewrite specs keyed as the
    matchers emit them.  Entries get default "op" and "name" fields when the
    matcher did not supply them.

    Returns:
        dict: Keyword-argument dicts consumable by ``graph.replace_custom_layer``.
    """
    match_map = {}
    # Per-layer-type counter used to generate unique default names.
    counter = Counter()
    for node in reversed(graph.nodes):
        if node.name in match_map:
            continue
        for layer_type, pattern_matcher in fusion_patterns.items():
            if pattern_matcher.match(node):
                match_case = pattern_matcher.rewrite(opset=graph.opset)
                logger.debug(f"matched pattern {layer_type}")
                # Bug-fix: the original inner loop reused the name `match`,
                # shadowing the match flag; use a distinct name and iterate
                # .values() since the keys are unused.
                for match_info in match_case.values():
                    if "op" not in match_info:
                        match_info["op"] = layer_type
                    if "name" not in match_info:
                        match_info["name"] = f"{layer_type.lower()}_{counter[layer_type]}"
                    counter.update([layer_type])
                match_map.update(match_case)

    return match_map
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_previous_node_by_type(node, op_type, trajectory=None):
    """Recursively find and return the first preceding node of a specified type in the computation graph.

    Returns the list of feed nodes visited along the way, ending with the first
    node whose ``op`` equals ``op_type``; implicitly returns ``None`` when the
    chain of feeds runs out without a hit.
    """
    # Fresh list per top-level call — avoids the shared mutable-default pitfall.
    if trajectory is None:
        trajectory = []
    node_feeds = node.feeds
    # NOTE(review): both branches below return during the first iteration, so
    # only the FIRST feed of each node is ever followed; sibling feeds are
    # never explored. Confirm this first-input-only walk is intended.
    for node_feed in node_feeds:
        trajectory.append(node_feed)
        if node_feed.op == op_type:
            return trajectory
        else:
            return get_previous_node_by_type(node_feed, op_type, trajectory)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
6
|
+
from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
|
|
7
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger("onnxslim")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def dead_node_elimination(graph, is_subgraph=False):
    """Perform in-place constant folding optimizations on the given computational graph by eliminating redundant
    nodes.

    Recurses into nested subgraphs first, then walks every node and erases (or
    simplifies) operations that provably do nothing.
    """
    # Handle nested subgraphs (e.g. If/Loop bodies) before the outer graph.
    for subgraph in graph.subgraphs():
        dead_node_elimination(subgraph, is_subgraph=True)

    for node in graph.nodes:
        if node.op in {"Identity", "Dropout"}:
            # Pure pass-throughs; only erased at the top level — subgraph
            # boundaries may reference them directly.
            if not is_subgraph:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Pad":
            # A Pad whose constant pads tensor is all zeros is a no-op.
            if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
                pad_value = node.inputs[1].values.tolist()
                # tolist() yields a scalar for 0-d arrays; normalize to a list.
                pad_value = pad_value if isinstance(pad_value, list) else [pad_value]
                if all(value == 0 for value in pad_value):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Cast":
            # Casting to the dtype the input already has is a no-op.
            inp_dtype = next(dtype_to_onnx(input.dtype) for input in node.inputs)
            if inp_dtype == node.attrs["to"]:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Reshape":
            # 1-D -> 1-D reshape cannot change the tensor.
            if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and (
                node.outputs[0].shape and len(node.outputs[0].shape) == 1
            ):
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            # Known, identical input/output shapes: also a no-op.
            elif node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            else:
                # If the output shape is fully positive except for at most one
                # symbolic dim, replace a dynamic shape input with a static
                # constant (the symbolic dim becomes -1, which Reshape infers).
                node_output_shape = node.outputs[0].shape
                if node_output_shape and check_shape(node_output_shape) and not isinstance(node.inputs[1], gs.Constant):
                    shapes = [shape if isinstance(shape, int) else -1 for shape in node_output_shape]
                    reshape_const = gs.Constant(
                        f"{node.inputs[1].name}_",
                        values=np.array(shapes, dtype=np.int64),
                    )
                    node.inputs.pop(1)
                    node.inputs.insert(1, reshape_const)
                    logger.debug(f"replacing {node.op} op: {node.name}")
        elif node.op == "Mul":
            # x * 1 == x — only when exactly one side is a constant.
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
            ):
                idx, constant_variable = get_constant_variable(node, return_idx=True)
                if np.all(constant_variable.values == 1):
                    # Keep the variable operand when erasing the node.
                    var_idx = 0 if idx == 1 else 1
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Add":
            # x + 0 == x; for non-scalar zeros the variable's shape must equal
            # the output shape so no broadcasting is lost by removing the Add.
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
            ):
                idx, constant_variable = get_constant_variable(node, return_idx=True)
                value = constant_variable.values
                var_idx = 0 if idx == 1 else 1
                if value.ndim == 0 and value == 0:
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 0) and (node.inputs[var_idx].shape == node.outputs[0].shape):
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Expand":
            # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256]
            # Expanding to the input's own shape, or by a scalar 1, is a no-op.
            if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if node.inputs[0].shape is not None and node.inputs[0].shape == node.outputs[0].shape:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif value.ndim == 0 and value == 1:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Concat":
            # Single-input Concat is a pass-through; empty constant inputs
            # contribute nothing and are dropped from the input list.
            if len(node.inputs) == 1:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            else:
                for input in node.inputs:
                    if isinstance(input, Constant) and input.values.size == 0:
                        node.inputs.remove(input)
        elif node.op == "Sub":
            # x - 0 == x (constant must be the subtrahend).
            if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if value.ndim == 0 and value == 0:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 0) and (node.inputs[0].shape == node.outputs[0].shape):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Div":
            # x / 1 == x (constant must be the divisor).
            if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if value.ndim == 0 and value == 1:
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Split":
            # A Split with a single output identical in shape to its input
            # splits nothing.
            if (
                len(node.outputs) == 1
                and node.outputs[0].shape
                and node.inputs[0].shape
                and node.outputs[0].shape == node.inputs[0].shape
            ):
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Resize":
            # Normalize a missing "mode" attribute to "nearest" (the ONNX default).
            mode = node.attrs.get("mode")
            if mode is None:
                node.attrs["mode"] = "nearest"
                logger.debug(f"setting mode to nearest for {node.op} op: {node.name} since it is not set")
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def check_shape(shapes):
    """Return True when every dim is a positive int, optionally with exactly one symbolic (string) dim."""
    symbolic = sum(1 for dim in shapes if isinstance(dim, str))
    positive = sum(1 for dim in shapes if isinstance(dim, int) and dim > 0)
    total = len(shapes)
    # Either one symbolic dim with all remaining dims positive, or fully static.
    return (symbolic == 1 and positive == total - 1) or positive == total
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_constant_variable(node, return_idx=False):
    """Return the first Constant among a node's inputs (or ``(idx, constant)`` when ``return_idx``); None if absent."""
    # Snapshot the inputs so concurrent mutation elsewhere cannot break iteration.
    for position, candidate in enumerate(list(node.inputs)):
        if isinstance(candidate, Constant):
            return (position, candidate) if return_idx else candidate
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable
|
|
4
|
+
|
|
5
|
+
logger = logging.getLogger("onnxslim")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def find_and_remove_replaceable_nodes(nodes):
    """Deduplicate equivalent nodes: later duplicates are redirected onto the
    first interchangeable node found in the same input-name bucket.
    """

    def bucket_key(candidate):
        # Key a node by the names of its Variable inputs; None when it has none.
        names = [inp.name for inp in candidate.inputs if isinstance(inp, Variable)]
        return "_".join(names) if names else None

    buckets = {}
    for candidate in nodes:
        key = bucket_key(candidate)
        if key:
            buckets.setdefault(key, []).append(candidate)

    for grouped in buckets.values():
        if len(grouped) <= 1:
            continue
        alive = [True] * len(grouped)
        for i, survivor in enumerate(grouped):
            if not alive[i]:
                continue
            for j in range(i + 1, len(grouped)):
                if alive[j] and can_be_replaced(survivor, grouped[j]):
                    alive[j] = False
                    grouped[j].replace_all_uses_with(survivor)
                    logger.debug(
                        f"Node {grouped[j].name} Op {grouped[j].op} can be replaced by {survivor.name}"
                    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def sequences_equal(seq1, seq2):
    """Return True when both sequences have the same length and pairwise-equal elements."""
    if len(seq1) != len(seq2):
        return False
    for lhs, rhs in zip(seq1, seq2):
        # Rely on == (the elements may define custom equality).
        if not (lhs == rhs):
            return False
    return True
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def can_be_replaced(node, other_node):
    """Two nodes are interchangeable when op, attributes, and non-empty inputs all agree."""
    if node.op != other_node.op or node.attrs != other_node.attrs:
        return False
    # Empty (optional) inputs do not affect equivalence.
    lhs_inputs = [inp for inp in node.inputs if not inp.is_empty()]
    rhs_inputs = [inp for inp in other_node.inputs if not inp.is_empty()]
    return sequences_equal(lhs_inputs, rhs_inputs)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def subexpression_elimination(graph):
    """Run common-subexpression elimination over the graph and, recursively, its subgraphs."""
    # Subgraphs first, so nested duplicates are collapsed before the outer pass.
    for subgraph in graph.subgraphs():
        subexpression_elimination(subgraph)

    grouped_by_op = {}
    for node in graph.nodes:
        grouped_by_op.setdefault(node.op, []).append(node)

    # Only nodes sharing an op type can possibly replace one another.
    for same_op_nodes in grouped_by_op.values():
        find_and_remove_replaceable_nodes(same_op_nodes)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger("onnxslim")
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
10
|
+
|
|
11
|
+
# Upper bound (element count) on constants considered for weight tying.
# Read the env var once (the original called os.getenv twice); `or 1000`
# preserves the fallback for both an unset and an empty variable.
THRESHOLD = int(os.getenv("ONNXSLIM_THRESHOLD") or 1000)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def tie_weights(graph):
    """Tie weights in a computational graph to reduce the number of parameters.

    Collects every constant tensor in the graph and all nested subgraphs,
    buckets them by shape (small tensors only), and collapses equal constants
    onto a single shared tensor.
    """
    all_constants = [t for t in graph.tensors().values() if isinstance(t, gs.Constant)]
    for sub_graph in graph.subgraphs(recursive=True):
        all_constants.extend(
            t for t in sub_graph.tensors().values() if isinstance(t, gs.Constant)
        )

    by_shape = defaultdict(list)
    for tensor in all_constants:
        shape_key = tuple(tensor.shape)
        # Only small constants are worth comparing element-wise.
        if np.prod(shape_key) < THRESHOLD:
            by_shape[shape_key].append(tensor)

    for candidates in by_shape.values():
        find_and_remove_replaceable_constants(candidates)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def find_and_remove_replaceable_constants(constant_tensors):
    """Collapse equal constants in the list onto a single surviving tensor."""

    def redirect_users(survivor, duplicate):
        # Rewire every consumer of `duplicate` to read from `survivor` instead.
        for consumer in list(duplicate.outputs):
            for position, candidate in enumerate(consumer.inputs):
                if (candidate == duplicate) and (candidate.name == duplicate.name):
                    consumer.inputs.pop(position)
                    consumer.inputs.insert(position, survivor)

    if len(constant_tensors) <= 1:
        return
    alive = [True] * len(constant_tensors)
    for i, survivor in enumerate(constant_tensors):
        if not alive[i]:
            continue
        for j in range(i + 1, len(constant_tensors)):
            if alive[j] and survivor == constant_tensors[j]:
                alive[j] = False
                redirect_users(survivor, constant_tensors[j])
                logger.debug(
                    f"Constant {constant_tensors[j].name} can be replaced by {survivor.name}"
                )
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import re
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
|
|
5
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
6
|
+
from onnxslim.third_party.onnx_graphsurgeon import Constant
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger("onnxslim")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_name(name):
    """Sanitize *name*: runs of characters outside [0-9a-zA-Z_] become a single
    underscore, and a purely numeric result is prefixed with an underscore.
    """
    sanitized = re.sub(r"[^0-9a-zA-Z_]+", "_", name)
    return f"_{sanitized}" if sanitized.isdigit() else sanitized
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NodeDescriptor:
    """
    Parsed representation of one whitespace-separated pattern line:
    ``Optype Name <in-spec> <out-spec> <input names...> <output names...>``.

    case 0: input [1, 2, 3, 4, 5] output [0] Optype Name 5 1 i0 i1 i2 i3 i4 o0
    case 1: input [1, ...] output [0] Optype Name 1+ 1 i0 o0
    case 2: input [..., 1, ...] output [0] Optype Name 1* 1 i0 o0.
    """

    def __init__(self, node_spec):
        """Initialize NodeDescriptor with node_spec list requiring at least 4 elements."""
        if not isinstance(node_spec, list):
            raise ValueError("node_spec must be a list")
        if len(node_spec) < 4:
            raise ValueError(f"node_spec must have at least 4 elements {node_spec}")

        def get_input_info(io_spec):
            """Parse an I/O count spec; returns (count, coarse, mode) where
            mode is "append" for ``N+``, "free-match" for ``N*``, None for ``N``.
            """
            if not io_spec.isdigit():
                match = re.search(r"(\d+)([+*])", io_spec)
                if match:
                    number = match.group(1)
                    operator = match.group(2)

                    if not number.isdigit():
                        raise ValueError(f"input_num and output_num must be integers {io_spec}")

                    if operator == "+":
                        return int(number), True, "append"
                    elif operator == "*":
                        return int(number), True, "free-match"
                    else:
                        raise ValueError(f"operator must be + or * {io_spec}")

            # Non-digit spec with no +/* suffix falls through here and raises
            # ValueError from int() — malformed specs are rejected either way.
            return int(io_spec), False, None

        self.op = node_spec[0]
        self.name = node_spec[1]
        self.input_num, self.coarse_input_num, self.input_mode = get_input_info(node_spec[2])
        self.output_num, self.coarse_output_num, self.output_mode = get_input_info(node_spec[3])
        # Input names immediately follow the two count specs; the rest are outputs.
        self.input_names = node_spec[4 : 4 + self.input_num]
        self.output_names = node_spec[4 + self.input_num :]
        # NOTE(review): these asserts vanish under `python -O`; consider raising
        # ValueError instead if malformed patterns must always be rejected.
        assert len(self.input_names) == self.input_num
        assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}"

    def __repr__(self):
        """Return a string representation of the object, including its name, operation type, input/output counts, and
        input/output names.
        """
        return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}"

    # NOTE(review): defining __dict__ as a method shadows the standard instance
    # __dict__ descriptor (vars(obj) no longer returns the attribute dict).
    # Presumably callers invoke it explicitly as obj.__dict__() — confirm
    # before renaming or removing.
    def __dict__(self):
        """Returns a dictionary representation of the object, with 'name' as the key."""
        return {
            "name": self,
        }
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class Pattern:
    """A subgraph description parsed into NodeDescriptor entries, one per line."""

    def __init__(self, pattern):
        """Store the pattern and parse it into node descriptors."""
        self.pattern = pattern
        self.nodes = self.parse_nodes()

    def parse_nodes(self):
        """Build one NodeDescriptor per non-empty, whitespace-separated line of the pattern."""
        descriptors = []
        for raw_line in self.pattern.split("\n"):
            spec = raw_line.strip().split()
            if spec:
                descriptors.append(NodeDescriptor(spec))
        return descriptors

    def match(self, node):
        """Delegate matching of *node* to the underlying pattern object."""
        return self.pattern.match(node)

    def __repr__(self):
        """The raw pattern text."""
        return self.pattern
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class PatternMatcher:
    """Matches a graph node (and its chain of producers) against a Pattern,
    binding each matched graph node onto ``self`` under its pattern name.
    """

    def __init__(self, pattern, priority):
        """Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output
        names.
        """
        self.pattern = pattern
        self.priority = priority  # relative ordering among registered patterns
        # name -> NodeDescriptor for every line of the pattern
        self.pattern_dict = {node.name: node for node in pattern.nodes}
        self.output_names = [node.name for node in pattern.nodes if node.op == "output"]

    def get_match_point(self):
        """Retrieve the match point node from the pattern dictionary based on output node input names."""
        # The anchor is the producer of the pattern's first output; matching
        # starts there and walks backwards through node feeds.
        return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]

    def reset_input(self):
        # Drop "input" placeholder bindings captured by a previous match() call
        # so stale bindings cannot alias nodes from another subgraph.
        for k, v in self.pattern_dict.items():
            if v.op == "input":
                if hasattr(self, v.name):
                    delattr(self, v.name)

    def match(self, node):
        """Match a given node to a pattern by comparing input names with the match point node from the pattern
        dictionary.
        """
        self.reset_input()
        match_point = self.get_match_point()

        def match_(node, pattern_node):
            """Recursively match *node* against *pattern_node*, binding matched
            graph nodes onto ``self`` under their pattern names.
            """
            if pattern_node.op == "input":
                # "input" placeholders match anything, but the same placeholder
                # must always bind to the same graph node.
                if hasattr(self, pattern_node.name):
                    if getattr(self, pattern_node.name) == node:
                        return True
                    else:
                        return False
                else:
                    setattr(self, pattern_node.name, node)
                    return True
            # node is an input variable
            if not hasattr(node, "op"):
                return False

            if node.op == pattern_node.op:
                setattr(self, pattern_node.name, node)

                node_feeds = node.feeds
                # "Coarse" specs (e.g. "2+") tolerate extra feeds; exact specs do not.
                if pattern_node.coarse_input_num:
                    if len(node_feeds) < len(pattern_node.input_names):
                        return False
                else:
                    if len(node_feeds) != len(pattern_node.input_names):
                        return False

                if pattern_node.input_mode == "append" or pattern_node.input_mode is None:
                    # Positional matching: feed i must match pattern input i
                    # ("?" entries match anything and bind nothing).
                    pattern_nodes = [
                        self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names
                    ]
                    all_match = True
                    # NOTE(review): the loop variable shadows the `pattern_node`
                    # parameter; both code paths below return before the outer
                    # name is read again, so the shadowing is currently harmless.
                    for node_feed, pattern_node in zip(node_feeds, pattern_nodes):
                        if pattern_node is not None:
                            node_match = match_(node_feed, pattern_node)
                            if not node_match:
                                return False
                            setattr(self, pattern_node.name, node_feed)

                    return all_match
                elif pattern_node.input_mode == "free-match":
                    # Order-free matching: each pattern input may match any feed.
                    pattern_nodes = [
                        self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names
                    ]
                    all_match = True
                    for pattern_node in pattern_nodes:
                        if pattern_node is not None:
                            node_match = False
                            for node_feed in node_feeds:
                                node_match = match_(node_feed, pattern_node)
                                if node_match:
                                    break
                            if not node_match:
                                return False
                            # Binds the feed where the inner loop stopped.
                            setattr(self, pattern_node.name, node_feed)

                    return all_match
            return False

        if match_(node, match_point):
            # Expose the matched node's outputs for rewrite().
            setattr(self, "output", node.outputs)
            if self.parameter_check():
                return True

        return False

    @abstractmethod
    def rewrite(self, opset=11):
        """Abstract method to rewrite the graph based on matched patterns, to be implemented by subclasses."""
        raise NotImplementedError("rewrite method must be implemented")

    def parameter_check(self):
        """Check and validate parameters, returning True if valid."""
        # Subclasses override this to reject structurally-matched subgraphs
        # whose attribute/parameter values make fusion invalid.
        return True
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class PatternGenerator:
    """Renders an ONNX model as the textual pattern language consumed by Pattern."""

    def __init__(self, onnx_model):
        """Import the ONNX model into a graphsurgeon graph and normalize it."""
        self.graph = gs.import_onnx(onnx_model)
        self.graph.fold_constants().cleanup().toposort()

    def generate(self):
        """Emit one pattern line per graph input, non-Constant node, and graph output."""
        lines = []

        for graph_input in self.graph.inputs:
            consumer_names = [get_name(out.name) for out in graph_input.outputs]
            lines.append(
                " ".join(
                    ["input", get_name(graph_input.name), "0", str(len(graph_input.outputs))] + consumer_names
                )
            )

        for node in self.graph.nodes:
            # Constant nodes carry no structure worth matching.
            if node.op == "Constant":
                continue
            feeds = node.feeds
            users = node.users
            # Constant feeds/users are wildcards ("?") in the pattern language.
            feed_names = ["?" if isinstance(feed, Constant) else get_name(feed.name) for feed in feeds]
            user_names = ["?" if isinstance(user, Constant) else get_name(user.name) for user in users]
            lines.append(
                " ".join(
                    [node.op, get_name(node.name), str(len(feeds)), str(len(users))] + feed_names + user_names
                )
            )

        for graph_output in self.graph.outputs:
            producer_names = [get_name(inp.name) for inp in graph_output.inputs]
            lines.append(
                " ".join(
                    ["output", get_name(graph_output.name), str(len(graph_output.inputs)), "0"] + producer_names
                )
            )

        return "\n".join(lines)
|