onnxslim 0.1.80__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0

onnxslim/third_party/_sympy/symbol.py
@@ -0,0 +1,102 @@
# mypy: allow-untyped-defs
"""
This file contains canonical definitions for our symbol naming conventions,
across torch.fx.experimental.symbolic_shapes and torch._inductor. The
intention is:

1. To make it easily greppable where all the sites we use a prefix are
2. Make it possible to easily tell if we can introduce a new prefix without
   introducing a conflict

You can occasionally test if prefixes have been hardcoded by renaming prefixes
in this file and seeing what breaks.
"""

from __future__ import annotations

from collections.abc import Iterable
from enum import Enum, auto

import sympy


class SymT(Enum):
    SIZE = auto()
    FLOAT = auto()
    UNBACKED_INT = auto()
    UNBACKED_FLOAT = auto()
    # Inductor: The intermediates in inner_fn tmp0, one generated per ops call.
    # If one of these shows up in an indexing expression, that means an
    # indirect load is happening.
    TMP = auto()
    # Inductor: Placeholder variable that is later replaced with TMP
    INDIRECT = auto()
    # Inductor: Some size expressions are replaced with a precomputed size ps0
    # which is computed host side, and then directly reused in the kernel, so
    # we don't repeatedly recompute it on device.
    PRECOMPUTED_SIZE = auto()
    # Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
    # dim in the loop
    INDEX = auto()
    # Inductor: A reduction indexing (r0, r1) variables in loops IR which ranges over
    # reduced dim(s) in the loop
    R0_INDEX = auto()
    R1_INDEX = auto()
    # Inductor: In templated kernels torch._inductor.kernel, we have a hook to
    # store the final output and append epilogue fusions. To do this, we must
    # know what the indexes the outputs range over. NB: These will also
    # advertise as INDEX, this is... probably OK?
    TEMPLATE_INDEX = auto()
    # Inductor: iteration domain for blockIdx.x/blockIdx.y
    XBLOCK = auto()
    YBLOCK = auto()
    ZBLOCK = auto()
    # Inductor: this is used solely for dynamic_reshape_indexer
    VIEW = auto()
    # Alternate (non-modular) indexing used in halide kernels
    HALIDE = auto()


# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity
prefix_str = {
    SymT.SIZE: "s",  # integer
    SymT.UNBACKED_INT: "u",  # integer
    # Prefix z here is chosen to avoid false aliasing in symbol_is_type test
    # DO NOT add a "z" type. You also need to avoid conflicts on these
    # prefixes but this is somewhat easier to manage
    SymT.FLOAT: "zf",
    SymT.UNBACKED_FLOAT: "zuf",
    SymT.TMP: "tmp",
    SymT.PRECOMPUTED_SIZE: "ps",
    SymT.INDEX: "i",
    SymT.R0_INDEX: "r0_",
    SymT.R1_INDEX: "r1_",
    SymT.TEMPLATE_INDEX: "idx",
    SymT.XBLOCK: "x",
    SymT.YBLOCK: "y",
    SymT.ZBLOCK: "z",
    SymT.INDIRECT: "indirect",  # false aliasing?
    SymT.VIEW: "view",
    SymT.HALIDE: "h",
}


def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
    # TODO: maybe put the assumptions here directly
    return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs)


# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
    assert isinstance(sym, sympy.Symbol)
    name_str = sym.name.lower()  # Match capitalized names like XBLOCK, RBLOCK
    if isinstance(prefix, SymT):
        return name_str.startswith(prefix_str[prefix])
    else:
        return name_str.startswith(tuple(prefix_str[p] for p in prefix))


def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
    return any(symbol_is_type(v, prefix) for v in e.free_symbols)
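For orientation, here is a minimal usage sketch of the symbol helpers above. This example is editorial, not part of the package; it assumes the hunk corresponds to onnxslim/third_party/_sympy/symbol.py, as the file list suggests, and that onnxslim is installed.

```python
# Editorial sketch: how make_symbol / symbol_is_type compose (not package code).
from onnxslim.third_party._sympy.symbol import (
    SymT,
    free_symbol_is_type,
    make_symbol,
    symbol_is_type,
)

# make_symbol builds names like "s0" or "u3" from a prefix plus an index;
# extra kwargs are passed straight to sympy.Symbol as assumptions.
s0 = make_symbol(SymT.SIZE, 0, integer=True, positive=True)
u3 = make_symbol(SymT.UNBACKED_INT, 3, integer=True)
assert s0.name == "s0" and u3.name == "u3"

# Classification is purely by name prefix, so it survives sympy round-trips.
assert symbol_is_type(s0, SymT.SIZE)
assert not symbol_is_type(s0, SymT.UNBACKED_INT)
assert free_symbol_is_type(s0 * u3 + 1, (SymT.SIZE, SymT.UNBACKED_INT))
```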

onnxslim/third_party/onnx_graphsurgeon/__init__.py
@@ -0,0 +1,15 @@
from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import export_onnx
from onnxslim.third_party.onnx_graphsurgeon.graph_pattern import (
    GraphPattern,
    PatternMapping,
)
from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import import_onnx
from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
from onnxslim.third_party.onnx_graphsurgeon.util.exception import (
    OnnxGraphSurgeonException,
)

__version__ = "0.5.1"
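This __init__ re-exports the public surface of the vendored onnx-graphsurgeon fork. A hedged, editorial sketch of how that surface is typically used (the graph built here is illustrative, not taken from the package):

```python
# Editorial sketch: build a tiny graph with the vendored API and round-trip it.
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs

x = gs.Variable("x", dtype=np.float32, shape=(1, 3))
w = gs.Constant("w", values=np.ones((3, 3), dtype=np.float32))
y = gs.Variable("y", dtype=np.float32, shape=(1, 3))

node = gs.Node(op="MatMul", inputs=[x, w], outputs=[y])
graph = gs.Graph(nodes=[node], inputs=[x], outputs=[y], opset=13)

model = gs.export_onnx(graph)        # Graph -> onnx.ModelProto
graph_again = gs.import_onnx(model)  # onnx.ModelProto -> Graph
```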

onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py
@@ -0,0 +1 @@
from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter

onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
@@ -0,0 +1,33 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


class BaseExporter:
    @staticmethod
    def export_graph(graph: Graph):
        """
        Export a graph to some destination graph.

        Args:
            graph (Graph): The source graph to export.

        Returns:
            object: The exported graph. For example, this might be an onnx.GraphProto
        """
        raise NotImplementedError("BaseExporter is an abstract class")
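BaseExporter only fixes the export_graph contract; OnnxExporter in onnx_exporter.py (the next hunk) is the real implementation. As a purely editorial illustration of that contract, not code from the package, a toy subclass might look like this:

```python
# Editorial sketch: a toy exporter that "exports" a graph as a dict summary,
# just to show the shape of the BaseExporter contract.
from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


class SummaryExporter(BaseExporter):
    @staticmethod
    def export_graph(graph: Graph):
        # The destination format is up to the subclass; here it is a plain dict.
        return {
            "name": graph.name,
            "num_nodes": len(graph.nodes),
            "inputs": [t.name for t in graph.inputs],
            "outputs": [t.name for t in graph.outputs],
        }
```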

onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
@@ -0,0 +1,432 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Sequence

import numpy as np
import onnx
from onnx import IR_VERSION, ModelProto, defs

from onnxslim.third_party.onnx_graphsurgeon.exporters.base_exporter import BaseExporter
from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import (
    Constant,
    LazyValues,
    SparseValues,
    Tensor,
    Variable,
)

from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
from onnxslim.third_party.onnx_graphsurgeon.util import misc

from ml_dtypes import bfloat16, float8_e4m3fn


def dtype_to_onnx(dtype: np.dtype | onnx.TensorProto.DataType) -> int:
    """Converts a numpy dtype or ONNX data type to its integer representation."""
    if isinstance(dtype, int):
        return dtype
    return onnx.helper.np_dtype_to_tensor_dtype(np.dtype(dtype))


def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
    """Check if node names are unique and log any duplicates based on the specified severity level."""
    # Note:
    # Empty string or None attribute values are not considered duplicates.
    name_map = {}
    for node in nodes:
        if not node.name:
            continue
        if node.name in name_map:
            msg = f"Found distinct Nodes that share the same name:\n[id: {id(name_map[node.name])}]:\n {name_map[node.name]}---\n[id: {id(node)}]:\n {node}\n"
            G_LOGGER.log(msg, level)
        else:
            name_map[node.name] = node


def update_import_domains(graph):
    """Update the import_domains field of a graph to include its ONNX opset and other used non-ONNX domains."""
    # as well as other non-ONNX domains which are used by this graph's nodes.
    # Returns the updated value of the import_domains field.

    # Add domain of the standard ONNX opset.
    if graph.import_domains is None:
        graph.import_domains = [onnx.helper.make_opsetid("", graph.opset)]

    # Crawl over all nodes in this graph and its subgraphs, and add the nodes' domains.
    all_used_domains = {node.domain for node in graph.nodes}
    for subgraph in graph.subgraphs(recursive=True):
        all_used_domains |= {n.domain for n in subgraph.nodes}
    all_used_domains.discard(None)

    # Update self.import_domains with any missing domains.
    current_domains = {opsetid.domain for opsetid in graph.import_domains}
    DEFAULT_CUSTOM_OPSET_VERSION = 1
    for used_domain in all_used_domains:
        if used_domain not in current_domains:
            graph.import_domains.append(onnx.helper.make_opsetid(used_domain, DEFAULT_CUSTOM_OPSET_VERSION))
            current_domains.add(used_domain)
    return graph.import_domains


def float32_to_bfloat16_uint16(x):
    """Convert a float32 value to bfloat16 represented as uint16."""
    return bfloat16(x).view(np.uint16)

def float32_to_float8e4m3(x):
    """Convert a float32 value to float8e4m3 represented as uint8."""
    return float8_e4m3fn(x).view(np.uint8)


class NumpyArrayConverter:
    def __init__(self, container, scalar_converter):
        self.func = np.vectorize(scalar_converter, otypes=[container])

    def __call__(self, arr):
        return self.func(arr)


_NUMPY_ARRAY_CONVERTERS = {
    onnx.TensorProto.BFLOAT16: NumpyArrayConverter(np.uint16, float32_to_bfloat16_uint16),
    # FP8 in TensorRT supports negative zeros, no infinities
    # See https://onnx.ai/onnx/technical/float8.html#papers
    onnx.TensorProto.FLOAT8E4M3FN: NumpyArrayConverter(np.uint8, float32_to_float8e4m3),
}


def constant_to_onnx_tensor(tensor: Constant) -> onnx.TensorProto:
    source_dtype = dtype_to_onnx(tensor.dtype)
    target_dtype = dtype_to_onnx(tensor.export_dtype)

    if source_dtype != target_dtype:
        source_dtype_str = onnx.helper.tensor_dtype_to_string(source_dtype)
        target_dtype_str = onnx.helper.tensor_dtype_to_string(target_dtype)
        assert source_dtype == onnx.TensorProto.FLOAT, (
            f"Cannot convert onnx dtype {source_dtype_str} to {target_dtype_str}. "
            "Source dtype must be float32 to convert to numpy unsupported dtypes."
        )
        assert target_dtype in _NUMPY_ARRAY_CONVERTERS.keys(), (
            f"Cannot convert onnx dtype {source_dtype_str} to {target_dtype_str}. "
            f"Only float32 to {_NUMPY_ARRAY_CONVERTERS.keys()} is supported."
        )
        arr = _NUMPY_ARRAY_CONVERTERS[target_dtype](tensor.values)
        tensor_raw_bytes = arr.tobytes()
    else:
        tensor_raw_bytes = tensor.values.tobytes()

    return onnx.helper.make_tensor(
        name=tensor.name,
        data_type=target_dtype,
        dims=tensor.shape,
        vals=tensor_raw_bytes,
        raw=True,
    )


class OnnxExporter(BaseExporter):
    @staticmethod
    def export_tensor_proto(tensor: Constant) -> onnx.TensorProto:
        # Do *not* load LazyValues into an intermediate numpy array - instead, use
        # the original onnx.TensorProto directly.
        if isinstance(tensor._values, LazyValues):
            onnx_tensor = tensor._values.tensor
            onnx_tensor.name = tensor.name
        else:
            onnx_tensor = constant_to_onnx_tensor(tensor)

        if tensor.data_location is not None:
            onnx_tensor.data_location = tensor.data_location
        return onnx_tensor

    @staticmethod
    def export_sparse_tensor_proto(tensor: Constant) -> onnx.SparseTensorProto:
        """Exports a given Constant tensor as an ONNX SparseTensorProto."""
        return tensor._values.tensor

    @staticmethod
    def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueInfoProto:
        """Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
        if do_type_check and tensor.dtype is None:
            G_LOGGER.critical(
                f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
            )

        if tensor.dtype is None:
            onnx_tensor = onnx.helper.make_empty_tensor_value_info(tensor.name)
        elif isinstance(tensor, Constant) or tensor.type == "tensor_type":
            onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape)
        elif tensor.type == "sequence_type":
            onnx_tensor = onnx.helper.make_tensor_sequence_value_info(
                tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape
            )
        elif tensor.type == "sparse_tensor_type":
            onnx_tensor = onnx.helper.make_sparse_tensor_value_info(
                tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape
            )
        return onnx_tensor

    @staticmethod
    def export_attributes(attrs: dict, subgraph_tensor_map=None) -> list[onnx.AttributeProto]:
        """Convert function attributes to ONNX AttributeProtos for model export."""
        onnx_attrs: list[onnx.AttributeProto] = []
        for key, val in attrs.items():
            if isinstance(val, Tensor):
                val = OnnxExporter.export_tensor_proto(val)
            elif isinstance(val, Graph):
                # Subgraphs don't need to have types specified for their tensors.
                graph = onnx.GraphProto()
                OnnxExporter.export_graph(graph, val, subgraph_tensor_map=subgraph_tensor_map, do_type_check=False)
                val = graph
            elif isinstance(val, Node.AttributeRef):
                onnx_attr = onnx.AttributeProto()
                onnx_attr.name = key
                onnx_attr.type = misc.convert_to_onnx_attr_type(val.type)

                # Netron has a bug which makes it crash if a Tensor attribute has no tensor data.
                # So provide some meaningless tensor data for Netron to read.
                if val.type == Tensor:
                    tensor_proto = OnnxExporter.export_tensor_proto(Constant("", np.array([0], dtype=np.float32)))
                    onnx_attr.t.CopyFrom(tensor_proto)

                onnx_attr.ref_attr_name = val.name
                onnx_attrs.append(onnx_attr)
                continue
            elif isinstance(val, type):
                # May be a numpy type
                try:
                    val = dtype_to_onnx(val)
                except TypeError:
                    pass
            onnx_attrs.append(onnx.helper.make_attribute(key, val))
        return onnx_attrs

    @staticmethod
    def export_node(node: Node, subgraph_tensor_map=None) -> onnx.NodeProto:
        # Cannot pass in attrs directly as make_node will change the order
        """Static method to convert an internal node to an ONNX node representation."""
        onnx_node = onnx.helper.make_node(
            node.op,
            inputs=[t.name for t in node.inputs],
            outputs=[t.name for t in node.outputs],
            name=node.name,
            domain=node.domain,
        )
        onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs, subgraph_tensor_map))
        return onnx_node

    @staticmethod
    def export_function(func: Function) -> onnx.FunctionProto:
        """
        Export an onnx-graphsurgeon Function to an ONNX FunctionProto.

        Args:
            func (Function): The function to export.
        """
        # Unlike onnx Graphs, onnx Functions don't have an 'initializer' field.
        # So we need to replace all Constant tensors with onnx Constant nodes which produce them.
        # We need to be careful to (a) preserve topological ordering and (b) not make the new nodes visible to the user.
        func_nodes = func.nodes.copy()
        new_const_nodes = [
            Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()])
            for tensor in func.tensors().values()
            if isinstance(tensor, Constant)
        ]
        # Const nodes have no inputs, so this maintains a topological ordering.
        func_nodes = new_const_nodes + func_nodes

        check_duplicate_node_names(func_nodes, level=G_LOGGER.WARNING)
        nodes = [OnnxExporter.export_node(node) for node in func_nodes]

        # Update the import_domains field to include all domains used by this function.
        opset_imports = update_import_domains(func)

        onnx_inputs = [inp.name for inp in func.inputs]
        onnx_outputs = [out.name for out in func.outputs]

        attributes = []
        attribute_protos = {}
        for attr_name, default_val in func.attrs.items():
            if default_val is None:
                attributes.append(attr_name)
            else:
                attribute_protos[attr_name] = default_val
        attribute_protos = OnnxExporter.export_attributes(attribute_protos)

        return onnx.helper.make_function(
            func.domain or "",
            func.name,
            onnx_inputs,
            onnx_outputs,
            nodes,
            opset_imports,
            attributes=attributes,
            attribute_protos=attribute_protos,
            doc_string=func.doc_string,
        )

    @staticmethod
    def export_graph(
        graph_proto: onnx.GraphProto,
        graph: Graph,
        tensor_map: OrderedDict[str, Tensor] | None = None,
        subgraph_tensor_map: OrderedDict[str, Tensor] | None = None,
        do_type_check=True,
    ) -> None:
        """
        Export an onnx-graphsurgeon Graph to an ONNX GraphProto.

        Args:
            graph (Graph): The graph to export.

            do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
                    Defaults to True.
        """
        check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING)
        nodes = [OnnxExporter.export_node(node, subgraph_tensor_map) for node in graph.nodes]
        inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs]
        outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs]
        if tensor_map is None:
            tensor_map = graph.tensors()
            tensor_map = misc.unique_dicts(tensor_map, subgraph_tensor_map)
        else:
            tensor_map = misc.combine_dicts(tensor_map, subgraph_tensor_map)
        initializer = [
            OnnxExporter.export_tensor_proto(tensor)
            for tensor in tensor_map.values()
            if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues)
        ]

        sparse_initializer = [
            OnnxExporter.export_sparse_tensor_proto(tensor)
            for tensor in tensor_map.values()
            if isinstance(tensor, Constant) and isinstance(tensor._values, SparseValues)
        ]

        # Remove inputs and outputs to export ValueInfoProtos
        for tensor in graph.inputs + graph.outputs:
            if tensor.name in tensor_map:
                del tensor_map[tensor.name]

        # Omit tensors from value_info if we don't know their shape/dtype
        def has_value_info(tensor):
            """Check if a tensor is a Variable with either a defined dtype or shape."""
            return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None)

        value_info = [
            OnnxExporter.export_value_info_proto(tensor, do_type_check)
            for tensor in tensor_map.values()
            if has_value_info(tensor)
        ]

        if initializer is None:
            initializer = []
        if sparse_initializer is None:
            sparse_initializer = []
        if value_info is None:
            value_info = []

        graph_proto.node.extend(nodes)
        graph_proto.name = graph.name
        graph_proto.input.extend(inputs)
        graph_proto.output.extend(outputs)
        for init in initializer:
            graph_proto.initializer.add().CopyFrom(init)
        for sparse in sparse_initializer:
            graph_proto.sparse_initializer.add().CopyFrom(sparse)
        graph_proto.value_info.extend(value_info)
        if graph.doc_string:
            graph_proto.doc_string = graph.doc_string

        return graph_proto


def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> onnx.ModelProto:
    """
    Exports an onnx-graphsurgeon Graph to an ONNX model.

    Args:
        graph (Graph): The graph to export

        do_type_check (bool): Whether to check that input and output tensors have data types defined, and fail if not.
                Defaults to True.
        kwargs: Additional arguments to onnx.helper.make_model

    Returns:
        onnx.ModelProto: A corresponding ONNX model.
    """
    sub_graphs = graph.subgraphs(recursive=True)

    graph_constants_list = [
        {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)}
        for sub_graph in sub_graphs
    ]

    if not graph_constants_list:
        intersection = None
    else:
        intersection = (
            {
                key: graph_constants_list[0][key]
                for key in graph_constants_list[0]
                if all(key in d and graph_constants_list[0][key] == d[key] for d in graph_constants_list[1:])
            }
            if graph_constants_list
            else None
        )

    model = ModelProto()  # create in advance to avoid unnecessary copy
    OnnxExporter.export_graph(
        model.graph, graph, tensor_map=graph.tensors(), subgraph_tensor_map=intersection, do_type_check=do_type_check
    )
    onnx_functions = [OnnxExporter.export_function(func) for func in graph.functions]
    kwargs["functions"] = onnx_functions

    if "opset_imports" not in kwargs:
        kwargs["opset_imports"] = update_import_domains(graph)

    if "ir_version" not in kwargs and graph.ir_version is not None:
        kwargs["ir_version"] = graph.ir_version
    else:
        model.ir_version = IR_VERSION

    opset_imports = None
    opset_imports = kwargs.pop("opset_imports", None)  # type: ignore
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    functions = None
    functions = kwargs.pop("functions", None)  # type: ignore
    if functions is not None:
        model.functions.extend(functions)

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)

    if graph.metadata_props is not None:
        model.metadata_props.extend(graph.metadata_props)
    model.producer_name = graph.producer_name
    model.producer_version = graph.producer_version
    return model
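To close the loop, here is an editorial usage sketch of the exporter above: load a model with import_onnx, then turn the Graph back into an onnx.ModelProto via export_onnx. The file paths are hypothetical placeholders, not part of the package.

```python
# Editorial sketch: round-trip a model through the vendored graphsurgeon exporter.
import onnx

from onnxslim.third_party.onnx_graphsurgeon import export_onnx, import_onnx

graph = import_onnx(onnx.load("model.onnx"))   # hypothetical input path
model = export_onnx(graph, do_type_check=True)  # Graph -> onnx.ModelProto

onnx.checker.check_model(model)
onnx.save(model, "model.exported.onnx")         # hypothetical output path
```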