onnxslim 0.1.71__tar.gz → 0.1.73__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {onnxslim-0.1.71/onnxslim.egg-info → onnxslim-0.1.73}/PKG-INFO +1 -1
- onnxslim-0.1.73/VERSION +1 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/__init__.py +10 -3
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/optimization/__init__.py +2 -2
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/optimization/dead_node_elimination.py +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/__init__.py +15 -2
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/elimination/slice.py +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/elimination/unsqueeze.py +2 -2
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/__init__.py +1 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/concat_reshape.py +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/convadd.py +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/convbn.py +1 -1
- onnxslim-0.1.73/onnxslim/core/pattern/fusion/convmul.py +69 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/gemm.py +27 -5
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/padconv.py +3 -3
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/registry.py +3 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/misc/tabulate.py +10 -8
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/_sympy/functions.py +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/_sympy/printers.py +3 -2
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/_sympy/symbol.py +4 -3
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +2 -2
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +5 -5
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/ir/function.py +9 -9
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +10 -10
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/ir/node.py +7 -7
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +5 -5
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/symbolic_shape_infer.py +111 -111
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/utils.py +6 -7
- onnxslim-0.1.73/onnxslim/version.py +1 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73/onnxslim.egg-info}/PKG-INFO +1 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim.egg-info/SOURCES.txt +1 -0
- onnxslim-0.1.71/VERSION +0 -1
- onnxslim-0.1.71/onnxslim/version.py +0 -1
- {onnxslim-0.1.71 → onnxslim-0.1.73}/LICENSE +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/MANIFEST.in +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/README.md +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/__main__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/argparser.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/cli/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/cli/_main.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/optimization/subexpression_elimination.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/optimization/weight_tying.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/elimination/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/elimination/concat.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/elimination/reshape.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/elimination/reshape_as.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/gelu.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/core/pattern/fusion/reduce.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/misc/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/_sympy/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/_sympy/numbers.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/_sympy/solve.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/util/exception.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim/third_party/onnx_graphsurgeon/util/misc.py +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim.egg-info/dependency_links.txt +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim.egg-info/entry_points.txt +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim.egg-info/requires.txt +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim.egg-info/top_level.txt +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/onnxslim.egg-info/zip-safe +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/pyproject.toml +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/setup.cfg +0 -0
- {onnxslim-0.1.71 → onnxslim-0.1.73}/setup.py +0 -0
onnxslim-0.1.73/VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
0.1.73
|
|
@@ -1,6 +1,9 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
import os
|
|
3
5
|
import tempfile
|
|
6
|
+
from typing import Optional
|
|
4
7
|
|
|
5
8
|
import numpy as np
|
|
6
9
|
import onnx
|
|
@@ -18,6 +21,7 @@ logger = logging.getLogger("onnxslim")
|
|
|
18
21
|
|
|
19
22
|
DEBUG = bool(os.getenv("ONNXSLIM_DEBUG"))
|
|
20
23
|
AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE")))
|
|
24
|
+
FORCE_ONNXRUNTIME_SHAPE_INFERENCE = bool(os.getenv("ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE"))
|
|
21
25
|
|
|
22
26
|
|
|
23
27
|
def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto:
|
|
@@ -122,6 +126,9 @@ def input_modification(model: onnx.ModelProto, inputs: str) -> onnx.ModelProto:
|
|
|
122
126
|
def shape_infer(model: onnx.ModelProto):
|
|
123
127
|
"""Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques."""
|
|
124
128
|
logger.debug("Start shape inference.")
|
|
129
|
+
if FORCE_ONNXRUNTIME_SHAPE_INFERENCE:
|
|
130
|
+
logger.debug("force onnxruntime shape infer.")
|
|
131
|
+
return SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
|
|
125
132
|
try:
|
|
126
133
|
logger.debug("try onnxruntime shape infer.")
|
|
127
134
|
model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
|
|
@@ -142,7 +149,7 @@ def shape_infer(model: onnx.ModelProto):
|
|
|
142
149
|
return model
|
|
143
150
|
|
|
144
151
|
|
|
145
|
-
def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None, size_threshold: int = None):
|
|
152
|
+
def optimize(model: onnx.ModelProto, skip_fusion_patterns: str | None = None, size_threshold: int | None = None):
|
|
146
153
|
"""Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model."""
|
|
147
154
|
logger.debug("Start converting model to gs.")
|
|
148
155
|
graph = gs.import_onnx(model).toposort()
|
|
@@ -171,11 +178,11 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
|
|
|
171
178
|
|
|
172
179
|
for node in graph.nodes:
|
|
173
180
|
if node.op == "Cast":
|
|
174
|
-
inp_dtype = [input.dtype for input in node.inputs][0]
|
|
181
|
+
inp_dtype = next(input.dtype for input in node.inputs)
|
|
175
182
|
if inp_dtype in [np.float16, np.float32]:
|
|
176
183
|
node.erase()
|
|
177
184
|
else:
|
|
178
|
-
outp_dtype = [output.dtype for output in node.outputs][0]
|
|
185
|
+
outp_dtype = next(output.dtype for output in node.outputs)
|
|
179
186
|
if outp_dtype == np.float16:
|
|
180
187
|
node.attrs["to"] = dtype_to_onnx(np.float32)
|
|
181
188
|
node.outputs[0].dtype = np.float32
|
|
@@ -51,7 +51,7 @@ class OptimizationSettings:
|
|
|
51
51
|
return any([getattr(cls, key) for key in cls.keys()])
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str = None) -> onnx.ModelProto:
|
|
54
|
+
def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str | None = None) -> onnx.ModelProto:
|
|
55
55
|
"""Optimize and transform the given ONNX model using various fusion patterns and graph rewriting techniques."""
|
|
56
56
|
graph = model if isinstance(model, gs.Graph) else gs.import_onnx(model)
|
|
57
57
|
if OptimizationSettings.graph_fusion:
|
|
@@ -85,7 +85,7 @@ def replace_custom_layer(
|
|
|
85
85
|
inputs,
|
|
86
86
|
outputs: list[str],
|
|
87
87
|
name: str,
|
|
88
|
-
attrs: dict = None,
|
|
88
|
+
attrs: dict | None = None,
|
|
89
89
|
domain: str = "ai.onnx.contrib",
|
|
90
90
|
):
|
|
91
91
|
"""Replace a custom layer in the computational graph with specified parameters and domain."""
|
|
@@ -29,7 +29,7 @@ def dead_node_elimination(graph, is_subgraph=False):
|
|
|
29
29
|
node.erase()
|
|
30
30
|
logger.debug(f"removing {node.op} op: {node.name}")
|
|
31
31
|
elif node.op == "Cast":
|
|
32
|
-
inp_dtype =
|
|
32
|
+
inp_dtype = next(dtype_to_onnx(input.dtype) for input in node.inputs)
|
|
33
33
|
if inp_dtype == node.attrs["to"]:
|
|
34
34
|
node.erase()
|
|
35
35
|
logger.debug(f"removing {node.op} op: {node.name}")
|
|
@@ -114,10 +114,17 @@ class PatternMatcher:
|
|
|
114
114
|
"""Retrieve the match point node from the pattern dictionary based on output node input names."""
|
|
115
115
|
return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]
|
|
116
116
|
|
|
117
|
+
def reset_input(self):
|
|
118
|
+
for k, v in self.pattern_dict.items():
|
|
119
|
+
if v.op == "input":
|
|
120
|
+
if hasattr(self, v.name):
|
|
121
|
+
delattr(self, v.name)
|
|
122
|
+
|
|
117
123
|
def match(self, node):
|
|
118
124
|
"""Match a given node to a pattern by comparing input names with the match point node from the pattern
|
|
119
125
|
dictionary.
|
|
120
126
|
"""
|
|
127
|
+
self.reset_input()
|
|
121
128
|
match_point = self.get_match_point()
|
|
122
129
|
|
|
123
130
|
def match_(node, pattern_node):
|
|
@@ -125,8 +132,14 @@ class PatternMatcher:
|
|
|
125
132
|
dictionary.
|
|
126
133
|
"""
|
|
127
134
|
if pattern_node.op == "input":
|
|
128
|
-
|
|
129
|
-
|
|
135
|
+
if hasattr(self, pattern_node.name):
|
|
136
|
+
if getattr(self, pattern_node.name) == node:
|
|
137
|
+
return True
|
|
138
|
+
else:
|
|
139
|
+
return False
|
|
140
|
+
else:
|
|
141
|
+
setattr(self, pattern_node.name, node)
|
|
142
|
+
return True
|
|
130
143
|
# node is an input variable
|
|
131
144
|
if not hasattr(node, "op"):
|
|
132
145
|
return False
|
|
@@ -58,7 +58,7 @@ class SlicePatternMatcher(PatternMatcher):
|
|
|
58
58
|
inputs = []
|
|
59
59
|
inputs.extend(
|
|
60
60
|
(
|
|
61
|
-
|
|
61
|
+
next(iter(first_slice_node.inputs)),
|
|
62
62
|
gs.Constant(
|
|
63
63
|
second_slice_node_inputs[1].name + "_starts",
|
|
64
64
|
values=np.array(new_starts, dtype=np.int64),
|
|
@@ -50,14 +50,14 @@ class UnsqueezePatternMatcher(PatternMatcher):
|
|
|
50
50
|
axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0
|
|
51
51
|
]
|
|
52
52
|
|
|
53
|
-
inputs = [
|
|
53
|
+
inputs = [next(iter(node_unsqueeze_0.inputs))]
|
|
54
54
|
outputs = list(node_unsqueeze_1.outputs)
|
|
55
55
|
|
|
56
56
|
index = node_unsqueeze_1.inputs.index(node_unsqueeze_0.outputs[0])
|
|
57
57
|
node_unsqueeze_1.inputs.pop(index)
|
|
58
58
|
for i, item in enumerate(node_unsqueeze_0.inputs):
|
|
59
59
|
node_unsqueeze_1.inputs.insert(index + i, item)
|
|
60
|
-
inputs = [
|
|
60
|
+
inputs = [next(iter(node_unsqueeze_1.inputs))]
|
|
61
61
|
outputs = list(node_unsqueeze_1.outputs)
|
|
62
62
|
node_unsqueeze_1.inputs.clear()
|
|
63
63
|
node_unsqueeze_1.outputs.clear()
|
|
@@ -36,7 +36,7 @@ class ConcatReshapeMatcher(PatternMatcher):
|
|
|
36
36
|
def rewrite(self, opset=11):
|
|
37
37
|
match_case = {}
|
|
38
38
|
concat_node = self.concat_0
|
|
39
|
-
index =
|
|
39
|
+
index = next(idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable))
|
|
40
40
|
constant = gs.Constant(
|
|
41
41
|
concat_node.inputs[index].name + "_fixed",
|
|
42
42
|
values=np.array([-1], dtype=np.int64),
|
|
@@ -41,7 +41,7 @@ class ConvAddMatcher(PatternMatcher):
|
|
|
41
41
|
conv_bias = conv_node.inputs[2].values + node.inputs[1].values.squeeze()
|
|
42
42
|
|
|
43
43
|
inputs = []
|
|
44
|
-
inputs.append(
|
|
44
|
+
inputs.append(next(iter(conv_node.inputs)))
|
|
45
45
|
inputs.append(conv_weight)
|
|
46
46
|
weight_name = list(conv_node.inputs)[1].name
|
|
47
47
|
if weight_name.endswith("weight"):
|
|
@@ -53,7 +53,7 @@ class ConvBatchNormMatcher(PatternMatcher):
|
|
|
53
53
|
conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt * bn_scale + bn_bias
|
|
54
54
|
|
|
55
55
|
inputs = []
|
|
56
|
-
inputs.append(
|
|
56
|
+
inputs.append(next(iter(conv_transpose_node.inputs)))
|
|
57
57
|
weight_name = list(conv_transpose_node.inputs)[1].name
|
|
58
58
|
if weight_name.endswith("weight"):
|
|
59
59
|
bias_name = f"{weight_name[:-6]}bias"
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
2
|
+
from onnxslim.core.pattern import Pattern, PatternMatcher
|
|
3
|
+
from onnxslim.core.pattern.registry import register_fusion_pattern
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConvMulMatcher(PatternMatcher):
|
|
7
|
+
def __init__(self, priority):
|
|
8
|
+
"""Initializes the ConvMulMatcher for fusing Conv and Mul layers in an ONNX graph."""
|
|
9
|
+
pattern = Pattern(
|
|
10
|
+
"""
|
|
11
|
+
input input 0 1 conv_0
|
|
12
|
+
Conv conv_0 1+ 1 input mul_0
|
|
13
|
+
Mul mul_0 2 1 conv_0 ? output
|
|
14
|
+
output output 1 0 mul_0
|
|
15
|
+
"""
|
|
16
|
+
)
|
|
17
|
+
super().__init__(pattern, priority)
|
|
18
|
+
|
|
19
|
+
@property
|
|
20
|
+
def name(self):
|
|
21
|
+
"""Returns the name of the FusionConvMul pattern."""
|
|
22
|
+
return "FusionConvMul"
|
|
23
|
+
|
|
24
|
+
def rewrite(self, opset=11):
|
|
25
|
+
match_case = {}
|
|
26
|
+
conv_node = self.conv_0
|
|
27
|
+
mul_node = self.mul_0
|
|
28
|
+
conv_weight = list(conv_node.inputs)[1]
|
|
29
|
+
if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
|
|
30
|
+
mul_constant = mul_node.inputs[1].values
|
|
31
|
+
|
|
32
|
+
if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[0]:
|
|
33
|
+
weight_shape = conv_weight.values.shape
|
|
34
|
+
reshape_shape = [-1] + [1] * (len(weight_shape) - 1)
|
|
35
|
+
|
|
36
|
+
mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
|
|
37
|
+
new_weight = conv_weight.values * mul_scale_reshaped
|
|
38
|
+
|
|
39
|
+
inputs = []
|
|
40
|
+
inputs.append(next(iter(conv_node.inputs)))
|
|
41
|
+
|
|
42
|
+
weight_name = list(conv_node.inputs)[1].name
|
|
43
|
+
inputs.append(gs.Constant(weight_name, values=new_weight))
|
|
44
|
+
|
|
45
|
+
if len(conv_node.inputs) == 3:
|
|
46
|
+
conv_bias = conv_node.inputs[2].values
|
|
47
|
+
new_bias = conv_bias * mul_constant.squeeze()
|
|
48
|
+
bias_name = list(conv_node.inputs)[2].name
|
|
49
|
+
inputs.append(gs.Constant(bias_name, values=new_bias))
|
|
50
|
+
|
|
51
|
+
outputs = list(mul_node.outputs)
|
|
52
|
+
|
|
53
|
+
conv_node.outputs.clear()
|
|
54
|
+
mul_node.inputs.clear()
|
|
55
|
+
mul_node.outputs.clear()
|
|
56
|
+
|
|
57
|
+
match_case[conv_node.name] = {
|
|
58
|
+
"op": conv_node.op,
|
|
59
|
+
"inputs": inputs,
|
|
60
|
+
"outputs": outputs,
|
|
61
|
+
"name": conv_node.name,
|
|
62
|
+
"attrs": conv_node.attrs,
|
|
63
|
+
"domain": None,
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
return match_case
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
register_fusion_pattern(ConvMulMatcher(1))
|
|
@@ -105,7 +105,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
|
|
|
105
105
|
}
|
|
106
106
|
)
|
|
107
107
|
|
|
108
|
-
values = list(input_variable.shape[:-1])
|
|
108
|
+
values = [*list(input_variable.shape[:-1]), matmul_bias_variable.values.shape[-1]]
|
|
109
109
|
post_reshape_const = gs.Constant(
|
|
110
110
|
f"{matmul_node.name}_post_reshape_in",
|
|
111
111
|
values=np.array(values, dtype=np.int64),
|
|
@@ -289,16 +289,38 @@ class GemmAddPatternMatcher(PatternMatcher):
|
|
|
289
289
|
)
|
|
290
290
|
and add_bias_variable
|
|
291
291
|
and len(reshape_node.users) == 1
|
|
292
|
+
and gemm_node.outputs[0].shape
|
|
292
293
|
):
|
|
294
|
+
|
|
295
|
+
def can_broadcast_to(shape_from, shape_to):
|
|
296
|
+
"""Return True if shape_from can broadcast to shape_to per NumPy rules."""
|
|
297
|
+
if shape_from is None or shape_to is None:
|
|
298
|
+
return False
|
|
299
|
+
try:
|
|
300
|
+
np.empty(shape_to, dtype=np.float32) + np.empty(shape_from, dtype=np.float32)
|
|
301
|
+
return True
|
|
302
|
+
except ValueError:
|
|
303
|
+
return False
|
|
304
|
+
|
|
293
305
|
gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
|
|
294
306
|
if gemm_bias_constant:
|
|
295
307
|
gemm_bias = gemm_bias_constant.values
|
|
296
308
|
add_bias = add_bias_variable.values
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
309
|
+
if (
|
|
310
|
+
can_broadcast_to(gemm_bias.shape, gemm_node.outputs[0].shape)
|
|
311
|
+
and can_broadcast_to(add_bias.shape, gemm_node.outputs[0].shape)
|
|
312
|
+
and add_bias.ndim <= 2
|
|
313
|
+
):
|
|
314
|
+
gemm_bias_fused = gemm_bias + add_bias
|
|
315
|
+
gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
|
|
316
|
+
gemm_node.inputs[2] = gemm_bias_fused_constant
|
|
317
|
+
else:
|
|
318
|
+
return match_case
|
|
300
319
|
else:
|
|
301
|
-
gemm_node.
|
|
320
|
+
if can_broadcast_to(add_bias_variable.values.shape, gemm_node.outputs[0].shape):
|
|
321
|
+
gemm_node.inputs[2] = add_bias_variable
|
|
322
|
+
else:
|
|
323
|
+
return match_case
|
|
302
324
|
|
|
303
325
|
add_node.replace_all_uses_with(reshape_node)
|
|
304
326
|
|
|
@@ -24,6 +24,7 @@ class PadConvMatcher(PatternMatcher):
|
|
|
24
24
|
def parameter_check(self) -> bool:
|
|
25
25
|
"""Validates if the padding parameter for a convolutional node is a constant."""
|
|
26
26
|
pad_node = self.pad_0
|
|
27
|
+
|
|
27
28
|
return isinstance(pad_node.inputs[1], gs.Constant)
|
|
28
29
|
|
|
29
30
|
def rewrite(self, opset=11):
|
|
@@ -37,13 +38,12 @@ class PadConvMatcher(PatternMatcher):
|
|
|
37
38
|
|
|
38
39
|
pad_inputs = len(pad_node.inputs)
|
|
39
40
|
if pad_inputs < 3 or (
|
|
40
|
-
pad_inputs >= 3
|
|
41
|
-
and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0)
|
|
41
|
+
(pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
|
|
42
42
|
or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
|
|
43
43
|
):
|
|
44
44
|
if (
|
|
45
45
|
isinstance(pad_node.inputs[1], gs.Constant)
|
|
46
|
-
and pad_node.attrs
|
|
46
|
+
and pad_node.attrs.get("mode", "constant") == "constant"
|
|
47
47
|
and conv_node.inputs[1].shape
|
|
48
48
|
):
|
|
49
49
|
conv_weight_dim = len(conv_node.inputs[1].shape)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from collections import OrderedDict
|
|
2
4
|
|
|
3
5
|
DEFAULT_FUSION_PATTERNS = OrderedDict()
|
|
@@ -12,7 +14,7 @@ def register_fusion_pattern(fusion_pattern):
|
|
|
12
14
|
DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern
|
|
13
15
|
|
|
14
16
|
|
|
15
|
-
def get_fusion_patterns(skip_fusion_patterns: str = None):
|
|
17
|
+
def get_fusion_patterns(skip_fusion_patterns: str | None = None):
|
|
16
18
|
"""Returns a copy of the default fusion patterns, optionally excluding specific patterns."""
|
|
17
19
|
default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
|
|
18
20
|
if skip_fusion_patterns:
|
|
@@ -24,7 +24,7 @@ def _is_file(f):
|
|
|
24
24
|
return isinstance(f, io.IOBase)
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
__all__ = ["
|
|
27
|
+
__all__ = ["simple_separated_format", "tabulate", "tabulate_formats"]
|
|
28
28
|
try:
|
|
29
29
|
from .version import version as __version__ # noqa: F401
|
|
30
30
|
except ImportError:
|
|
@@ -1202,7 +1202,7 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True):
|
|
|
1202
1202
|
if val is None:
|
|
1203
1203
|
return missingval
|
|
1204
1204
|
|
|
1205
|
-
if valtype is str or valtype is not int and valtype is not bytes and valtype is not float:
|
|
1205
|
+
if valtype is str or (valtype is not int and valtype is not bytes and valtype is not float):
|
|
1206
1206
|
return f"{val}"
|
|
1207
1207
|
elif valtype is int:
|
|
1208
1208
|
return format(val, intfmt)
|
|
@@ -1276,7 +1276,7 @@ def _prepend_row_index(rows, index):
|
|
|
1276
1276
|
index_iter = iter(index)
|
|
1277
1277
|
for row in sans_rows:
|
|
1278
1278
|
index_v = next(index_iter)
|
|
1279
|
-
new_rows.append([index_v
|
|
1279
|
+
new_rows.append([index_v, *list(row)])
|
|
1280
1280
|
rows = new_rows
|
|
1281
1281
|
_reinsert_separating_lines(rows, separating_lines)
|
|
1282
1282
|
return rows
|
|
@@ -1325,7 +1325,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
|
|
|
1325
1325
|
"""
|
|
1326
1326
|
try:
|
|
1327
1327
|
bool(headers)
|
|
1328
|
-
is_headers2bool_broken = False
|
|
1328
|
+
is_headers2bool_broken = False
|
|
1329
1329
|
except ValueError: # numpy.ndarray, pandas.core.index.Index, ...
|
|
1330
1330
|
is_headers2bool_broken = True # noqa
|
|
1331
1331
|
headers = list(headers)
|
|
@@ -1423,7 +1423,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
|
|
|
1423
1423
|
if headers == "firstrow":
|
|
1424
1424
|
if len(rows) > 0:
|
|
1425
1425
|
if index is not None:
|
|
1426
|
-
headers = [index[0]
|
|
1426
|
+
headers = [index[0], *list(rows[0])]
|
|
1427
1427
|
index = index[1:]
|
|
1428
1428
|
else:
|
|
1429
1429
|
headers = rows[0]
|
|
@@ -1606,7 +1606,7 @@ def tabulate(
|
|
|
1606
1606
|
given header. Possible values are: "global" (no override), "same"
|
|
1607
1607
|
(follow column alignment), "right", "center", "left".
|
|
1608
1608
|
|
|
1609
|
-
Note on intended behaviour: If there is no `tabular_data`, any column
|
|
1609
|
+
Note on intended behavior: If there is no `tabular_data`, any column
|
|
1610
1610
|
alignment argument is ignored. Hence, in this case, header
|
|
1611
1611
|
alignment cannot be inferred from column alignment.
|
|
1612
1612
|
|
|
@@ -2534,8 +2534,10 @@ class _CustomTextWrap(textwrap.TextWrapper):
|
|
|
2534
2534
|
if (
|
|
2535
2535
|
self.max_lines is None
|
|
2536
2536
|
or len(lines) + 1 < self.max_lines
|
|
2537
|
-
or (
|
|
2538
|
-
|
|
2537
|
+
or (
|
|
2538
|
+
(not chunks or (self.drop_whitespace and len(chunks) == 1 and not chunks[0].strip()))
|
|
2539
|
+
and cur_len <= width
|
|
2540
|
+
)
|
|
2539
2541
|
):
|
|
2540
2542
|
# Convert current line back to a string and store it in
|
|
2541
2543
|
# list of all lines (return value).
|
|
@@ -52,7 +52,7 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
|
|
|
52
52
|
|
|
53
53
|
We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, where n is the greatest common integer factor and
|
|
54
54
|
e is the largest syntactic common factor (i.e., common sub-expression) in p and q. Then the gcd returned is n*e,
|
|
55
|
-
|
|
55
|
+
canceling which we would be left with p1 + p2 and q0.
|
|
56
56
|
|
|
57
57
|
Note that further factoring of p1 + p2 and q0 might be possible with sympy.factor (which uses domain-specific
|
|
58
58
|
theories). E.g., we are unable to find that x*y + x + y + 1 is divisible by x + 1. More generally, when q is of the
|
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import sys
|
|
2
|
-
from typing import Optional
|
|
3
4
|
|
|
4
5
|
import sympy
|
|
5
6
|
from sympy.printing.precedence import PRECEDENCE, precedence
|
|
@@ -22,7 +23,7 @@ class ExprPrinter(StrPrinter):
|
|
|
22
23
|
def _print_Not(self, expr: sympy.Expr) -> str:
|
|
23
24
|
return f"not ({self._print(expr.args[0])})"
|
|
24
25
|
|
|
25
|
-
def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
|
|
26
|
+
def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
|
|
26
27
|
return self.stringify(expr.args, " + ", precedence(expr))
|
|
27
28
|
|
|
28
29
|
def _print_Relational(self, expr: sympy.Expr) -> str:
|
|
@@ -12,9 +12,10 @@ You can occasionally test if prefixes have been hardcoded by renaming prefixes
|
|
|
12
12
|
in this file and seeing what breaks.
|
|
13
13
|
"""
|
|
14
14
|
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
15
17
|
from collections.abc import Iterable
|
|
16
18
|
from enum import Enum, auto
|
|
17
|
-
from typing import Union
|
|
18
19
|
|
|
19
20
|
import sympy
|
|
20
21
|
|
|
@@ -88,7 +89,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
|
|
|
88
89
|
|
|
89
90
|
# This type is a little wider than it should be, because free_symbols says
|
|
90
91
|
# that it contains Basic, rather than Symbol
|
|
91
|
-
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
|
|
92
|
+
def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
|
|
92
93
|
assert isinstance(sym, sympy.Symbol)
|
|
93
94
|
name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
|
|
94
95
|
if isinstance(prefix, SymT):
|
|
@@ -97,5 +98,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> boo
|
|
|
97
98
|
return name_str.startswith(tuple(prefix_str[p] for p in prefix))
|
|
98
99
|
|
|
99
100
|
|
|
100
|
-
def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
|
|
101
|
+
def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
|
|
101
102
|
return any(symbol_is_type(v, prefix) for v in e.free_symbols)
|
|
@@ -278,8 +278,8 @@ class OnnxExporter(BaseExporter):
|
|
|
278
278
|
def export_graph(
|
|
279
279
|
graph_proto: onnx.GraphProto,
|
|
280
280
|
graph: Graph,
|
|
281
|
-
tensor_map: OrderedDict[str, Tensor] = None,
|
|
282
|
-
subgraph_tensor_map: OrderedDict[str, Tensor] = None,
|
|
281
|
+
tensor_map: OrderedDict[str, Tensor] | None = None,
|
|
282
|
+
subgraph_tensor_map: OrderedDict[str, Tensor] | None = None,
|
|
283
283
|
do_type_check=True,
|
|
284
284
|
) -> None:
|
|
285
285
|
"""
|
|
@@ -82,7 +82,7 @@ class PatternMapping(dict):
|
|
|
82
82
|
def __str__(self) -> str:
|
|
83
83
|
"""Returns a string representation of the pattern mapping, including inputs, outputs, and constants."""
|
|
84
84
|
if self.onnx_node is None:
|
|
85
|
-
return "{" + str.join(", ", [f"{key}: {
|
|
85
|
+
return "{" + str.join(", ", [f"{key}: {value!s}" for key, value in self.items()]) + "}"
|
|
86
86
|
return self.onnx_node.name
|
|
87
87
|
|
|
88
88
|
|
|
@@ -357,7 +357,7 @@ class OnnxImporter(BaseImporter):
|
|
|
357
357
|
@staticmethod
|
|
358
358
|
def import_function(
|
|
359
359
|
onnx_function: onnx.FunctionProto,
|
|
360
|
-
model_opset: int = None,
|
|
360
|
+
model_opset: int | None = None,
|
|
361
361
|
model_import_domains: onnx.OperatorSetIdProto = None,
|
|
362
362
|
) -> Function:
|
|
363
363
|
"""Imports an ONNX function to a Function object using the model opset and import domains."""
|
|
@@ -405,13 +405,13 @@ class OnnxImporter(BaseImporter):
|
|
|
405
405
|
@staticmethod
|
|
406
406
|
def import_graph(
|
|
407
407
|
onnx_graph: onnx.GraphProto,
|
|
408
|
-
tensor_map: OrderedDict[str, Tensor] = None,
|
|
408
|
+
tensor_map: OrderedDict[str, Tensor] | None = None,
|
|
409
409
|
opset=None,
|
|
410
410
|
import_domains: onnx.OperatorSetIdProto = None,
|
|
411
411
|
ir_version=None,
|
|
412
|
-
producer_name: str = None,
|
|
413
|
-
producer_version: str = None,
|
|
414
|
-
functions: list[Function] = None,
|
|
412
|
+
producer_name: str | None = None,
|
|
413
|
+
producer_version: str | None = None,
|
|
414
|
+
functions: list[Function] | None = None,
|
|
415
415
|
metadata_props=None,
|
|
416
416
|
) -> Graph:
|
|
417
417
|
"""
|
|
@@ -43,15 +43,15 @@ class Function(Graph):
|
|
|
43
43
|
def __init__(
|
|
44
44
|
self,
|
|
45
45
|
name: str,
|
|
46
|
-
domain: str = None,
|
|
47
|
-
nodes: Sequence[Node] = None,
|
|
48
|
-
inputs: Sequence[Tensor] = None,
|
|
49
|
-
outputs: Sequence[Tensor] = None,
|
|
50
|
-
doc_string: str = None,
|
|
51
|
-
opset: int = None,
|
|
52
|
-
import_domains: Sequence[onnx.OperatorSetIdProto] = None,
|
|
53
|
-
functions: Sequence[Function] = None,
|
|
54
|
-
attrs: dict = None,
|
|
46
|
+
domain: str | None = None,
|
|
47
|
+
nodes: Sequence[Node] | None = None,
|
|
48
|
+
inputs: Sequence[Tensor] | None = None,
|
|
49
|
+
outputs: Sequence[Tensor] | None = None,
|
|
50
|
+
doc_string: str | None = None,
|
|
51
|
+
opset: int | None = None,
|
|
52
|
+
import_domains: Sequence[onnx.OperatorSetIdProto] | None = None,
|
|
53
|
+
functions: Sequence[Function] | None = None,
|
|
54
|
+
attrs: dict | None = None,
|
|
55
55
|
):
|
|
56
56
|
"""
|
|
57
57
|
Args:
|
|
@@ -100,17 +100,17 @@ class Graph:
|
|
|
100
100
|
|
|
101
101
|
def __init__(
|
|
102
102
|
self,
|
|
103
|
-
nodes: Sequence[Node] = None,
|
|
104
|
-
inputs: Sequence[Tensor] = None,
|
|
105
|
-
outputs: Sequence[Tensor] = None,
|
|
103
|
+
nodes: Sequence[Node] | None = None,
|
|
104
|
+
inputs: Sequence[Tensor] | None = None,
|
|
105
|
+
outputs: Sequence[Tensor] | None = None,
|
|
106
106
|
name=None,
|
|
107
107
|
doc_string=None,
|
|
108
108
|
opset=None,
|
|
109
109
|
import_domains=None,
|
|
110
110
|
ir_version=None,
|
|
111
|
-
producer_name: str = None,
|
|
112
|
-
producer_version: str = None,
|
|
113
|
-
functions: Sequence[Function] = None,
|
|
111
|
+
producer_name: str | None = None,
|
|
112
|
+
producer_version: str | None = None,
|
|
113
|
+
functions: Sequence[Function] | None = None,
|
|
114
114
|
metadata_props=None,
|
|
115
115
|
):
|
|
116
116
|
"""
|
|
@@ -347,7 +347,7 @@ class Graph:
|
|
|
347
347
|
func_ids.add(func.unique_id)
|
|
348
348
|
return self.functions
|
|
349
349
|
|
|
350
|
-
for graph in self.functions
|
|
350
|
+
for graph in [*self.functions, self]:
|
|
351
351
|
for subgraph in graph.subgraphs(recursive=True):
|
|
352
352
|
new_list = absorb_function_list(subgraph.functions)
|
|
353
353
|
subgraph._functions = new_list
|
|
@@ -774,7 +774,7 @@ class Graph:
|
|
|
774
774
|
if len(node.attrs) != 1:
|
|
775
775
|
G_LOGGER.warning("Constant node must contain exactly one attribute")
|
|
776
776
|
continue
|
|
777
|
-
attr_name, attr_val =
|
|
777
|
+
attr_name, attr_val = next(iter(node.attrs.items()))
|
|
778
778
|
allowed_attrs = {
|
|
779
779
|
"value",
|
|
780
780
|
"value_float",
|
|
@@ -975,7 +975,7 @@ class Graph:
|
|
|
975
975
|
|
|
976
976
|
def get_scalar_value(tensor):
|
|
977
977
|
"""Gets the scalar value of a constant tensor with a single item."""
|
|
978
|
-
return
|
|
978
|
+
return next(iter(tensor.values)) if tensor.shape else tensor.values
|
|
979
979
|
|
|
980
980
|
def fold_shape(tensor):
|
|
981
981
|
"""Returns the input tensor shape if available, otherwise returns None."""
|
|
@@ -1409,7 +1409,7 @@ class Graph:
|
|
|
1409
1409
|
self.nodes.append(node)
|
|
1410
1410
|
return node.outputs
|
|
1411
1411
|
|
|
1412
|
-
def copy(self, tensor_map: OrderedDict[str, Tensor] = None):
|
|
1412
|
+
def copy(self, tensor_map: OrderedDict[str, Tensor] | None = None):
|
|
1413
1413
|
"""
|
|
1414
1414
|
Copy the graph.
|
|
1415
1415
|
|
|
@@ -42,11 +42,11 @@ class Node:
|
|
|
42
42
|
def __init__(
|
|
43
43
|
self,
|
|
44
44
|
op: str,
|
|
45
|
-
name: str = None,
|
|
46
|
-
attrs: dict[str, object] = None,
|
|
47
|
-
inputs: list[Tensor] = None,
|
|
48
|
-
outputs: list[Tensor] = None,
|
|
49
|
-
domain: str = None,
|
|
45
|
+
name: str | None = None,
|
|
46
|
+
attrs: dict[str, object] | None = None,
|
|
47
|
+
inputs: list[Tensor] | None = None,
|
|
48
|
+
outputs: list[Tensor] | None = None,
|
|
49
|
+
domain: str | None = None,
|
|
50
50
|
):
|
|
51
51
|
"""
|
|
52
52
|
A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors.
|
|
@@ -152,8 +152,8 @@ class Node:
|
|
|
152
152
|
|
|
153
153
|
def copy(
|
|
154
154
|
self,
|
|
155
|
-
inputs: list[Tensor] = None,
|
|
156
|
-
outputs: list[Tensor] = None,
|
|
155
|
+
inputs: list[Tensor] | None = None,
|
|
156
|
+
outputs: list[Tensor] | None = None,
|
|
157
157
|
tensor_map=None,
|
|
158
158
|
):
|
|
159
159
|
"""
|