onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,146 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from collections import Counter
5
+ from typing import List, Optional, Union
6
+
7
+ import onnx
8
+
9
+ import onnxslim.third_party.onnx_graphsurgeon as gs
10
+ from onnxslim.core.pattern.registry import get_fusion_patterns
11
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
12
+
13
+ logger = logging.getLogger("onnxslim")
14
+
15
+ from .dead_node_elimination import dead_node_elimination
16
+ from .subexpression_elimination import subexpression_elimination
17
+ from .weight_tying import tie_weights
18
+
19
+
20
class OptimizationSettings:
    """Process-wide toggles controlling which onnxslim optimization passes run.

    Flags are stored as class attributes so they act as global settings;
    `reset` re-enables everything except the passes named in
    ``skip_optimizations``.
    """

    constant_folding = True
    graph_fusion = True
    dead_node_elimination = True
    subexpression_elimination = True
    weight_tying = True

    @classmethod
    def keys(cls):
        """Return the names of all known optimization flags."""
        return [
            "constant_folding",
            "graph_fusion",
            "dead_node_elimination",
            "subexpression_elimination",
            "weight_tying",
        ]

    @classmethod
    def reset(cls, skip_optimizations: list[str] | None = None):
        """Enable every pass, then disable the ones listed in skip_optimizations."""
        for key in cls.keys():
            # A flag is on unless it was explicitly asked to be skipped.
            setattr(cls, key, not (skip_optimizations and key in skip_optimizations))

    @classmethod
    def stats(cls):
        """Return a mapping of flag name -> enabled state."""
        return {key: getattr(cls, key) for key in cls.keys()}

    @classmethod
    def enabled(cls):
        """Return True if at least one optimization pass is enabled."""
        # Generator expression: any() short-circuits without building a list.
        return any(getattr(cls, key) for key in cls.keys())
52
+
53
+
54
def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str | None = None) -> onnx.ModelProto:
    """Optimize and transform the given ONNX model using various fusion patterns and graph rewriting techniques.

    Args:
        model: An in-memory ``onnx.ModelProto`` or an already-imported graphsurgeon Graph.
        skip_fusion_patterns: Pattern names to exclude from fusion; forwarded to get_fusion_patterns.

    Returns:
        The optimized model exported back to an ``onnx.ModelProto``.

    Each pass is gated by the corresponding OptimizationSettings flag. The pass
    order matters: fusion first, then dead-node elimination, common-subexpression
    elimination, and finally weight tying; cleanup/toposort runs between passes
    so later passes see a consistent graph.
    """
    graph = model if isinstance(model, gs.Graph) else gs.import_onnx(model)
    if OptimizationSettings.graph_fusion:
        logger.debug("Start graph_fusion.")
        fusion_patterns = get_fusion_patterns(skip_fusion_patterns)
        graph_fusion(graph, fusion_patterns)
        logger.debug("Finish graph_fusion.")
    if OptimizationSettings.dead_node_elimination:
        logger.debug("Start dead_node_elimination.")
        dead_node_elimination(graph)
        # Purge nodes/inputs orphaned by the eliminations before the next pass.
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish dead_node_elimination.")
    if OptimizationSettings.subexpression_elimination:
        logger.debug("Start subexpression_elimination.")
        subexpression_elimination(graph)
        graph.cleanup(remove_unused_graph_inputs=True).toposort()
        logger.debug("Finish subexpression_elimination.")
    if OptimizationSettings.weight_tying:
        logger.debug("Start weight_tying.")
        tie_weights(graph)
        logger.debug("Finish weight_tying.")
    model = gs.export_onnx(graph)

    return model
79
+
80
+
81
@gs.Graph.register()
def replace_custom_layer(
    self,
    op: str,
    inputs,
    outputs: list[str],
    name: str,
    attrs: dict | None = None,
    domain: str = "ai.onnx.contrib",
):
    """Insert a layer with the given op, IO, name and attributes under a custom domain."""
    # Collect the layer configuration once, then forward it to Graph.layer.
    layer_spec = {
        "op": op,
        "inputs": inputs,
        "outputs": outputs,
        "name": name,
        "attrs": attrs,
        "domain": domain,
    }
    return self.layer(**layer_spec)
100
+
101
+
102
def graph_fusion(graph: Graph, fusion_patterns: dict, is_subgraph=False):
    """Apply fusion patterns to `graph` (and recursively to its subgraphs) in place.

    Args:
        graph: graphsurgeon Graph to rewrite.
        fusion_patterns: mapping of layer-type name -> pattern matcher.
        is_subgraph: True on recursive calls; unused graph inputs are only
            removed on the top-level graph.
    """
    # Fuse nested subgraphs (If/Loop bodies) before the parent graph.
    for subgraph in graph.subgraphs():
        graph_fusion(subgraph, fusion_patterns, is_subgraph=True)

    fusion_pairs = find_matches(graph, fusion_patterns)
    for match in fusion_pairs.values():
        graph.replace_custom_layer(**match)

    # Simplified from `True if not is_subgraph else False`: only the top-level
    # graph may drop unused graph inputs.
    graph.cleanup(remove_unused_graph_inputs=not is_subgraph).toposort()
111
+
112
+
113
def find_matches(graph: Graph, fusion_patterns: dict):
    """Find matching patterns in the graph based on provided fusion patterns."""
    match_map = {}
    counter = Counter()

    # Walk nodes in reverse order; skip nodes already claimed by an earlier match.
    for node in reversed(graph.nodes):
        if node.name in match_map:
            continue
        for layer_type, pattern_matcher in fusion_patterns.items():
            if not pattern_matcher.match(node):
                continue
            match_case = pattern_matcher.rewrite(opset=graph.opset)
            logger.debug(f"matched pattern {layer_type}")
            for match_info in match_case.values():
                # Fill in defaults for entries the rewriter left unspecified.
                match_info.setdefault("op", layer_type)
                if "name" not in match_info:
                    match_info["name"] = f"{layer_type.lower()}_{counter[layer_type]}"
                    counter.update([layer_type])
            match_map.update(match_case)

    return match_map
134
+
135
+
136
def get_previous_node_by_type(node, op_type, trajectory=None):
    """Walk producers upstream and return the path to the first node of `op_type`.

    Only the FIRST producer of each node is followed (the loop returns on its
    first iteration either way). Returns the list of visited producer nodes
    ending with the matching one, or None if the chain runs out of producers.
    """
    if trajectory is None:
        trajectory = []
    for predecessor in node.feeds:
        trajectory.append(predecessor)
        if predecessor.op == op_type:
            return trajectory
        return get_previous_node_by_type(predecessor, op_type, trajectory)
@@ -0,0 +1,151 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+
5
+ import onnxslim.third_party.onnx_graphsurgeon as gs
6
+ from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
7
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable
8
+
9
+ logger = logging.getLogger("onnxslim")
10
+
11
+
12
def dead_node_elimination(graph, is_subgraph=False):
    """Perform in-place constant folding optimizations on the given computational graph by eliminating redundant
    nodes.

    Args:
        graph: onnx_graphsurgeon graph to mutate in place; nested subgraphs are processed first.
        is_subgraph: True on recursive calls; Identity/Dropout removal is skipped inside subgraphs.
    """
    # Recurse into nested subgraphs (If/Loop bodies) before the parent graph.
    for subgraph in graph.subgraphs():
        dead_node_elimination(subgraph, is_subgraph=True)

    for node in graph.nodes:
        if node.op in {"Identity", "Dropout"}:
            # Pure pass-through ops. Only erased at the top level; subgraph
            # boundary rewiring is not attempted here.
            if not is_subgraph:
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Pad":
            # A Pad whose constant `pads` input is all zeros does nothing.
            if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
                pad_value = node.inputs[1].values.tolist()
                # tolist() returns a bare scalar for 0-d arrays; normalize to a list.
                pad_value = pad_value if isinstance(pad_value, list) else [pad_value]
                if all(value == 0 for value in pad_value):
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Cast":
            # next(...) reads the first input's dtype (Cast has a single data input).
            inp_dtype = next(dtype_to_onnx(input.dtype) for input in node.inputs)
            if inp_dtype == node.attrs["to"]:
                # Casting to the dtype the tensor already has is a no-op.
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Reshape":
            if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and (
                node.outputs[0].shape and len(node.outputs[0].shape) == 1
            ):
                # 1-D -> 1-D reshape cannot change anything.
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            elif node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape:
                # Statically identical input/output shapes: no-op reshape.
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            else:
                # If the output shape has at most one symbolic dim, replace a dynamic
                # shape input with a static int64 constant (-1 for the symbolic dim).
                node_output_shape = node.outputs[0].shape
                if node_output_shape and check_shape(node_output_shape) and not isinstance(node.inputs[1], gs.Constant):
                    shapes = [shape if isinstance(shape, int) else -1 for shape in node_output_shape]
                    reshape_const = gs.Constant(
                        f"{node.inputs[1].name}_",
                        values=np.array(shapes, dtype=np.int64),
                    )
                    node.inputs.pop(1)
                    node.inputs.insert(1, reshape_const)
                    logger.debug(f"replacing {node.op} op: {node.name}")
        elif node.op == "Mul":
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
            ):
                idx, constant_variable = get_constant_variable(node, return_idx=True)
                if np.all(constant_variable.values == 1):
                    # x * 1 == x: forward the variable input (var_idx) to output 0.
                    var_idx = 0 if idx == 1 else 1
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Add":
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
            ):
                idx, constant_variable = get_constant_variable(node, return_idx=True)
                value = constant_variable.values
                var_idx = 0 if idx == 1 else 1
                if value.ndim == 0 and value == 0:
                    # x + scalar 0 == x.
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 0) and (node.inputs[var_idx].shape == node.outputs[0].shape):
                    # Adding an all-zero tensor is only a no-op when broadcasting
                    # would not change the output shape; hence the shape guard.
                    node.erase(var_idx, 0)
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Expand":
            # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256]
            if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if node.inputs[0].shape is not None and node.inputs[0].shape == node.outputs[0].shape:
                    # Expand that provably keeps the shape is a no-op.
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif value.ndim == 0 and value == 1:
                    # Expanding by scalar 1 is a no-op.
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Concat":
            if len(node.inputs) == 1:
                # Concat of a single tensor is the identity.
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
            else:
                # Drop constant inputs with zero elements; they contribute nothing.
                # NOTE(review): removing from node.inputs while iterating it skips the
                # element following each removal — confirm at most one empty input is expected.
                for input in node.inputs:
                    if isinstance(input, Constant) and input.values.size == 0:
                        node.inputs.remove(input)
        elif node.op == "Sub":
            if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if value.ndim == 0 and value == 0:
                    # x - scalar 0 == x.
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 0) and (node.inputs[0].shape == node.outputs[0].shape):
                    # All-zero subtrahend with no broadcast-induced shape change.
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Div":
            if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
                constant_variable = node.inputs[1]
                value = constant_variable.values
                if value.ndim == 0 and value == 1:
                    # x / scalar 1 == x.
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
                elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape):
                    # All-ones divisor with no broadcast-induced shape change.
                    node.erase()
                    logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Split":
            if (
                len(node.outputs) == 1
                and node.outputs[0].shape
                and node.inputs[0].shape
                and node.outputs[0].shape == node.inputs[0].shape
            ):
                # A single-output Split that preserves the shape is the identity.
                node.erase()
                logger.debug(f"removing {node.op} op: {node.name}")
        elif node.op == "Resize":
            mode = node.attrs.get("mode")
            if mode is None:
                # "nearest" is ONNX's default for Resize's mode attribute; make it explicit.
                node.attrs["mode"] = "nearest"
                logger.debug(f"setting mode to nearest for {node.op} op: {node.name} since it is not set")
131
+
132
+
133
def check_shape(shapes):
    """Return True when `shapes` has at most one symbolic (string) dimension and every other entry is a positive int."""
    num_strings = sum(1 for dim in shapes if isinstance(dim, str))
    num_positive_ints = sum(1 for dim in shapes if isinstance(dim, int) and dim > 0)

    if num_strings == 1:
        # Exactly one symbolic dim: everything else must be a positive int.
        return num_positive_ints == len(shapes) - 1
    # Otherwise the shape is acceptable only if fully static and positive.
    return num_positive_ints == len(shapes)
145
+
146
+
147
def get_constant_variable(node, return_idx=False):
    """Return the first Constant among the node's inputs, optionally as an (index, constant) pair."""
    for position, candidate in enumerate(list(node.inputs)):
        if isinstance(candidate, Constant):
            return (position, candidate) if return_idx else candidate
@@ -0,0 +1,76 @@
1
+ import logging
2
+
3
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable
4
+
5
+ logger = logging.getLogger("onnxslim")
6
+
7
+
8
def find_and_remove_replaceable_nodes(nodes):
    """Find and remove duplicate or replaceable nodes in a given list of computational graph nodes."""

    def get_node_key(node):
        # Bucket key: the joined names of Variable inputs (Constants are ignored).
        variable_names = [inp.name for inp in node.inputs if isinstance(inp, Variable)]
        return "_".join(variable_names) if variable_names else None

    buckets = {}
    for node in nodes:
        key = get_node_key(node)
        if key:
            buckets.setdefault(key, []).append(node)

    for bucketed_nodes in buckets.values():
        if len(bucketed_nodes) <= 1:
            continue
        keep = [True] * len(bucketed_nodes)
        for i, survivor in enumerate(bucketed_nodes):
            if not keep[i]:
                continue
            for j in range(i + 1, len(bucketed_nodes)):
                if keep[j] and can_be_replaced(survivor, bucketed_nodes[j]):
                    # Rewire all uses of the duplicate onto the surviving node.
                    keep[j] = False
                    duplicate = bucketed_nodes[j]
                    duplicate.replace_all_uses_with(survivor)
                    logger.debug(
                        f"Node {duplicate.name} Op {duplicate.op} can be replaced by {survivor.name}"
                    )
41
+
42
+
43
def sequences_equal(seq1, seq2):
    """Return True when both sequences have the same length and pairwise-equal elements."""
    if len(seq1) != len(seq2):
        return False
    return all(a == b for a, b in zip(seq1, seq2))
50
+
51
+
52
def can_be_replaced(node, other_node):
    """Two nodes are interchangeable when op type, attributes, and non-empty inputs all match."""
    if node.op != other_node.op or node.attrs != other_node.attrs:
        return False
    # Empty optional inputs are ignored when comparing the input lists.
    lhs_inputs = [inp for inp in node.inputs if not inp.is_empty()]
    rhs_inputs = [inp for inp in other_node.inputs if not inp.is_empty()]
    return sequences_equal(lhs_inputs, rhs_inputs)
60
+
61
+
62
def subexpression_elimination(graph):
    """Perform subexpression elimination on a computational graph to optimize node operations."""
    # Handle nested subgraphs first.
    for subgraph in graph.subgraphs():
        subexpression_elimination(subgraph)

    # Group nodes by op type; only same-op nodes can be duplicates.
    nodes_by_op = {}
    for node in graph.nodes:
        nodes_by_op.setdefault(node.op, []).append(node)

    for same_op_nodes in nodes_by_op.values():
        find_and_remove_replaceable_nodes(same_op_nodes)
@@ -0,0 +1,59 @@
1
+ import logging
2
+ import os
3
+
4
+ logger = logging.getLogger("onnxslim")
5
+ from collections import defaultdict
6
+
7
+ import numpy as np
8
+
9
+ import onnxslim.third_party.onnx_graphsurgeon as gs
10
+
11
# Element-count cutoff below which constants are candidates for weight tying;
# override via the ONNXSLIM_THRESHOLD env var. `or 1000` avoids the double
# os.getenv() lookup of the original and also falls back when the variable is
# set to an empty string.
THRESHOLD = int(os.getenv("ONNXSLIM_THRESHOLD") or 1000)
12
+
13
+
14
def tie_weights(graph):
    """Tie weights in a computational graph to reduce the number of parameters."""
    # Collect Constant tensors from the top-level graph and every nested subgraph.
    constant_tensors = [t for t in graph.tensors().values() if isinstance(t, gs.Constant)]
    for sub_graph in graph.subgraphs(recursive=True):
        constant_tensors.extend(
            tensor for tensor in sub_graph.tensors().values() if isinstance(tensor, gs.Constant)
        )

    # Bucket small constants by shape; only same-shape tensors can possibly be tied.
    constant_by_shape = defaultdict(list)
    for constant_tensor in constant_tensors:
        shape = tuple(constant_tensor.shape)
        if np.prod(shape) < THRESHOLD:
            constant_by_shape[shape].append(constant_tensor)

    for candidates in constant_by_shape.values():
        find_and_remove_replaceable_constants(candidates)
36
+
37
+
38
def find_and_remove_replaceable_constants(constant_tensors):
    """Deduplicate equal constant tensors by rewiring consumers onto one surviving tensor."""

    def replace_constant_references(existing_constant, to_be_removed_constant):
        # Swap every consumer input slot that still references the duplicate.
        for user in list(to_be_removed_constant.outputs):
            for idx, inp in enumerate(user.inputs):
                if (inp == to_be_removed_constant) and (inp.name == to_be_removed_constant.name):
                    user.inputs.pop(idx)
                    user.inputs.insert(idx, existing_constant)

    if len(constant_tensors) <= 1:
        return
    keep = [True] * len(constant_tensors)
    for i, survivor in enumerate(constant_tensors):
        if not keep[i]:
            continue
        for j in range(i + 1, len(constant_tensors)):
            if keep[j] and survivor == constant_tensors[j]:
                keep[j] = False
                replace_constant_references(survivor, constant_tensors[j])
                logger.debug(
                    f"Constant {constant_tensors[j].name} can be replaced by {survivor.name}"
                )
@@ -0,0 +1,249 @@
1
+ import logging
2
+ import re
3
+ from abc import abstractmethod
4
+
5
+ import onnxslim.third_party.onnx_graphsurgeon as gs
6
+ from onnxslim.third_party.onnx_graphsurgeon import Constant
7
+
8
+ logger = logging.getLogger("onnxslim")
9
+
10
+
11
def get_name(name):
    """Sanitize `name`: collapse runs of illegal characters to '_' and prefix all-digit results with '_'."""
    sanitized = re.sub(r"[^0-9a-zA-Z_]+", "_", name)
    # A purely numeric identifier is not valid; make it start with an underscore.
    return f"_{sanitized}" if sanitized.isdigit() else sanitized
21
+
22
+
23
class NodeDescriptor:
    """
    case 0: input [1, 2, 3, 4, 5] output [0] Optype Name 5 1 i0 i1 i2 i3 i4 o0
    case 1: input [1, ...] output [0] Optype Name 1+ 1 i0 o0
    case 2: input [..., 1, ...] output [0] Optype Name 1* 1 i0 o0.
    """

    def __init__(self, node_spec):
        """Parse a whitespace-split pattern line: op, name, input-count spec, output-count spec, IO names."""
        if not isinstance(node_spec, list):
            raise ValueError("node_spec must be a list")
        if len(node_spec) < 4:
            raise ValueError(f"node_spec must have at least 4 elements {node_spec}")

        self.op = node_spec[0]
        self.name = node_spec[1]
        self.input_num, self.coarse_input_num, self.input_mode = self._parse_io_spec(node_spec[2])
        self.output_num, self.coarse_output_num, self.output_mode = self._parse_io_spec(node_spec[3])
        self.input_names = node_spec[4 : 4 + self.input_num]
        self.output_names = node_spec[4 + self.input_num :]
        assert len(self.input_names) == self.input_num
        assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}"

    @staticmethod
    def _parse_io_spec(io_spec):
        """Return (count, is_coarse, mode): '<n>' exact, '<n>+' append-matching, '<n>*' free-matching."""
        if io_spec.isdigit():
            return int(io_spec), False, None
        matched = re.search(r"(\d+)([+*])", io_spec)
        if matched:
            number = matched.group(1)
            operator = matched.group(2)
            if not number.isdigit():
                raise ValueError(f"input_num and output_num must be integers {io_spec}")
            if operator == "+":
                return int(number), True, "append"
            if operator == "*":
                return int(number), True, "free-match"
            raise ValueError(f"operator must be + or * {io_spec}")
        # Neither a plain count nor a "<n>+/<n>*" spec; int() raises for garbage.
        return int(io_spec), False, None

    def __repr__(self):
        """Human-readable summary: name, op type, IO counts, and IO names."""
        return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}"

    def __dict__(self):
        """Returns a dictionary representation of the object, with 'name' as the key."""
        return {
            "name": self,
        }
79
+
80
+
81
class Pattern:
    """A textual graph pattern: one whitespace-separated NodeDescriptor spec per non-empty line."""

    def __init__(self, pattern):
        """Store the raw pattern and parse it into NodeDescriptor entries."""
        self.pattern = pattern
        self.nodes = self.parse_nodes()

    def parse_nodes(self):
        """Build a NodeDescriptor for every non-empty line of the pattern."""
        specs = (line.strip().split() for line in self.pattern.split("\n") if line)
        return [NodeDescriptor(spec) for spec in specs if spec]

    def match(self, node):
        """Delegate matching of `node` to the wrapped pattern object."""
        return self.pattern.match(node)

    def __repr__(self):
        """Return the raw pattern text."""
        return self.pattern
101
+
102
+
103
class PatternMatcher:
    """Matches a textual Pattern against graph nodes; subclasses implement the rewrite."""

    def __init__(self, pattern, priority):
        """Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output
        names.
        """
        self.pattern = pattern
        self.priority = priority
        # name -> NodeDescriptor lookup for every node in the pattern.
        self.pattern_dict = {node.name: node for node in pattern.nodes}
        self.output_names = [node.name for node in pattern.nodes if node.op == "output"]

    def get_match_point(self):
        """Retrieve the match point node from the pattern dictionary based on output node input names."""
        # The producer of the first pattern output is the anchor matching starts from.
        return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]

    def reset_input(self):
        # Drop bindings captured for "input" placeholders by a previous match attempt.
        for k, v in self.pattern_dict.items():
            if v.op == "input":
                if hasattr(self, v.name):
                    delattr(self, v.name)

    def match(self, node):
        """Match a given node to a pattern by comparing input names with the match point node from the pattern
        dictionary.
        """
        self.reset_input()
        match_point = self.get_match_point()

        def match_(node, pattern_node):
            """Recursively match `node` and its feeds against `pattern_node`, binding matched
            graph nodes as attributes on self under the pattern node's name.
            """
            if pattern_node.op == "input":
                # "input" placeholders match anything, but the same placeholder
                # must always bind to the same graph node.
                if hasattr(self, pattern_node.name):
                    if getattr(self, pattern_node.name) == node:
                        return True
                    else:
                        return False
                else:
                    setattr(self, pattern_node.name, node)
                    return True
            # node is an input variable
            if not hasattr(node, "op"):
                return False

            if node.op == pattern_node.op:
                setattr(self, pattern_node.name, node)

                node_feeds = node.feeds
                if pattern_node.coarse_input_num:
                    # "n+"/"n*" specs: the node needs at least as many feeds as named inputs.
                    if len(node_feeds) < len(pattern_node.input_names):
                        return False
                else:
                    if len(node_feeds) != len(pattern_node.input_names):
                        return False

                if pattern_node.input_mode == "append" or pattern_node.input_mode is None:
                    # Positional matching: i-th feed against i-th pattern input ("?" matches anything).
                    pattern_nodes = [
                        self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names
                    ]
                    all_match = True
                    # NOTE(review): the loop variable shadows the outer `pattern_node`;
                    # kept byte-identical here, but worth renaming upstream.
                    for node_feed, pattern_node in zip(node_feeds, pattern_nodes):
                        if pattern_node is not None:
                            node_match = match_(node_feed, pattern_node)
                            if not node_match:
                                return False
                            setattr(self, pattern_node.name, node_feed)

                    return all_match
                elif pattern_node.input_mode == "free-match":
                    # Order-free matching: each named pattern input must match some feed.
                    pattern_nodes = [
                        self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names
                    ]
                    all_match = True
                    for pattern_node in pattern_nodes:
                        if pattern_node is not None:
                            node_match = False
                            for node_feed in node_feeds:
                                node_match = match_(node_feed, pattern_node)
                                if node_match:
                                    break
                            if not node_match:
                                return False
                            # Binds the feed left in `node_feed` after the inner loop —
                            # the matching one, since the loop breaks on success.
                            setattr(self, pattern_node.name, node_feed)

                    return all_match
            return False

        if match_(node, match_point):
            # Expose the matched node's outputs for the rewrite step.
            setattr(self, "output", node.outputs)
            if self.parameter_check():
                return True

        return False

    @abstractmethod
    def rewrite(self, opset=11):
        """Abstract method to rewrite the graph based on matched patterns, to be implemented by subclasses."""
        raise NotImplementedError("rewrite method must be implemented")

    def parameter_check(self):
        """Check and validate parameters, returning True if valid."""
        return True
205
+
206
+
207
class PatternGenerator:
    """Serializes an ONNX model into the textual pattern format consumed by Pattern."""

    def __init__(self, onnx_model):
        """Import the model as a graphsurgeon graph and normalize it (fold, clean, sort)."""
        self.graph = gs.import_onnx(onnx_model)
        self.graph.fold_constants().cleanup().toposort()

    def generate(self):
        """Emit one pattern line per graph input, non-Constant node, and graph output."""
        lines = []

        # Graph inputs: "input <name> 0 <n_consumers> <consumer names...>".
        for graph_input in self.graph.inputs:
            consumers = [get_name(out.name) for out in graph_input.outputs]
            fields = ["input", get_name(graph_input.name), "0", str(len(graph_input.outputs))]
            lines.append(" ".join(fields + consumers))

        # Nodes: "<op> <name> <n_feeds> <n_users> <feeds...> <users...>"; Constants
        # are skipped as nodes and rendered as "?" wherever they appear as IO.
        for node in self.graph.nodes:
            if node.op == "Constant":
                continue
            feeds = node.feeds
            users = node.users
            feed_names = ["?" if isinstance(feed, Constant) else get_name(feed.name) for feed in feeds]
            user_names = ["?" if isinstance(user, Constant) else get_name(user.name) for user in users]
            fields = [node.op, get_name(node.name), str(len(feeds)), str(len(users))]
            lines.append(" ".join(fields + feed_names + user_names))

        # Graph outputs: "output <name> <n_producers> 0 <producer names...>".
        for graph_output in self.graph.outputs:
            producers = [get_name(inp.name) for inp in graph_output.inputs]
            fields = ["output", get_name(graph_output.name), str(len(graph_output.inputs)), "0"]
            lines.append(" ".join(fields + producers))

        return "\n".join(lines)
@@ -0,0 +1,5 @@
1
+ from .concat import *
2
+ from .reshape import *
3
+ from .reshape_as import *
4
+ from .slice import *
5
+ from .unsqueeze import *