onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from collections.abc import Sequence
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
|
|
24
|
+
from onnxslim.third_party.onnx_graphsurgeon.util import misc
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Tensor:
    """Abstract base class for tensors in a graph."""

    DYNAMIC = -1

    def __init__(self):
        """**This class is abstract and cannot be constructed directly.**."""
        raise NotImplementedError("Tensor is an abstract class")

    def __setattr__(self, name, value):
        """
        Set an attribute, with special handling for the "inputs" and "outputs" attributes.

        When ``inputs``/``outputs`` already exists on the instance, the existing container is
        cleared and refilled with ``value`` instead of being rebound, so the container object
        itself is preserved. Every other attribute is stored normally.
        """
        if name in {"inputs", "outputs"}:
            try:
                attr = getattr(self, name)
                if value is attr:
                    # This can happen when using things like +=
                    # The __iadd__ is executed followed by an assignment
                    return

                attr.clear()
                attr.extend(value)
            except AttributeError:
                # The attribute does not exist yet (or does not support clear/extend):
                # store the value directly.
                super().__setattr__(name, value)
        else:
            super().__setattr__(name, value)

    def is_empty(self):
        """
        Returns whether this tensor is considered empty in the graph.

        *Note: 'Empty' here refers to the name of the tensor, which is omitted for
        optional tensors, NOT the shape of the tensor*

        Returns:
            bool: Whether the tensor is empty, meaning that it is used for an omitted optional input or output.
        """
        return self.name == ""

    def to_constant(
        self,
        values: np.ndarray,
        data_location: int | None = None,
        export_dtype: np.dtype | onnx.TensorProto.DataType = None,
    ):
        """
        Modifies this tensor in-place to convert it to a Constant. This means that all consumers/producers of the tensor
        will see the update.

        Args:
            values (np.ndarray): The values in this tensor

            data_location (int):
                    An enum value indicating the location where the tensor data is stored.
                    Generally, this will come from onnx.TensorProto.DataLocation.

            export_dtype (Union[numpy.dtype, onnx.TensorProto.DataType]):
                    The data type of the tensor when exported to onnx.

        Returns:
            self
        """
        # Re-class the instance first so that the assignments below go through
        # Constant's attributes/properties.
        self.__class__ = Constant
        self._values = values
        self.data_location = data_location
        self.export_dtype = export_dtype

        return self

    def to_variable(self, dtype: np.dtype | onnx.TensorProto.DataType = None, shape: Sequence[int | str] | None = None):
        """
        Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor
        will see the update.

        Args:
            dtype (Union[numpy.dtype, onnx.TensorProto.DataType]): The data type of the tensor.
                    When omitted, ``self.export_dtype`` is used.
            shape (Sequence[int]): The shape of the tensor. Defaults to an empty list.

        Returns:
            self
        """
        if shape is None:
            shape = []
        # Resolve the dtype before changing the class.
        variable_dtype = dtype if dtype is not None else self.export_dtype

        self.__class__ = Variable
        self.shape = shape
        self.dtype = variable_dtype

        return self

    def i(self, tensor_idx=0, producer_idx=0):
        """
        Convenience function to get an input tensor of one of this tensor's input nodes. Note that the parameters are
        swapped compared to the o() function; this is because tensors are likely to have only a single producer.

        For example:
        ::

            assert tensor.i() == tensor.inputs[0].inputs[0]
            assert tensor.i(1, 2) == tensor.inputs[2].inputs[1]

        Args:
            tensor_idx (int): The index of the input tensor of the input node. Defaults to 0.
            producer_idx (int): The index of the producer node of the input tensor, if the tensor has multiple producers. Defaults to 0.

        Returns:
            Tensor: The specified producer (input) tensor.
        """
        return self.inputs[producer_idx].inputs[tensor_idx]

    def o(self, consumer_idx=0, tensor_idx=0):
        """
        Convenience function to get an output tensor of one of this tensor's output nodes.

        For example:
        ::

            assert tensor.o() == tensor.outputs[0].outputs[0]
            assert tensor.o(2, 1) == tensor.outputs[2].outputs[1]

        Args:
            consumer_idx (int): The index of the consumer of the input tensor. Defaults to 0.
            tensor_idx (int): The index of the output tensor of the node, if the node has multiple outputs. Defaults to 0.

        Returns:
            Tensor: The specified consumer (output) tensor
        """
        return self.outputs[consumer_idx].outputs[tensor_idx]

    def __str__(self):
        """Returns a string representation of the object including its type, name, shape, and data type."""
        return f"{type(self).__name__} ({self.name}): (shape={self.shape}, dtype={self.dtype})"

    def __repr__(self):  # Hack to make logging output pretty.
        """Returns a string representation of the object for logging output."""
        return self.__str__()

    def __eq__(self, other):
        """
        Perform a check to see if two tensors are equal.

        Tensors are considered equal if they share the same name. A Graph must not include Tensors with duplicate names.
        Objects that have no ``name`` attribute (e.g. ``None``) never compare equal.
        """
        if not hasattr(other, "name"):
            # Comparing against a non-tensor must answer False rather than raise AttributeError.
            return False
        return self.name == other.name

    @property
    def is_input(self):
        """Indicates whether this tensor is an input tensor in the graph (False until explicitly set)."""
        return getattr(self, "_is_input", False)

    @is_input.setter
    def is_input(self, is_input: bool = False):
        """Mark or unmark this tensor as an input tensor in the graph."""
        self._is_input = is_input

    @property
    def is_output(self):
        """Indicates if tensor is marked as an output within the computational graph (False until explicitly set)."""
        return getattr(self, "_is_output", False)

    @is_output.setter
    def is_output(self, is_output: bool = False):
        """Mark or unmark this tensor as an output tensor in the graph."""
        self._is_output = is_output
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class Variable(Tensor):
    """A tensor whose value is not known until inference-time."""

    @staticmethod
    def empty():
        """Build a Variable whose name is the empty string."""
        return Variable("")

    def __init__(
        self,
        name: str,
        dtype: np.dtype | onnx.TensorProto.DataType = None,
        shape: Sequence[int | str] | None = None,
        type: str = "tensor_type",
    ):
        """
        Represents a Tensor whose value is not known until inference-time.

        Args:
            name (str): The name of the tensor.
            dtype (Union[numpy.dtype, onnx.TensorProto.DataType]): The data type of the tensor.
            shape (Sequence[Union[int, str]]): The shape of the tensor. This may contain strings if the model uses dimension parameters.
            type (str): The type of the tensor.
        """
        self.name = name
        # Producer and consumer lists, each created empty and tied to this tensor.
        producers = misc.SynchronizedList(self, field_name="outputs", initial=[])
        self.inputs = producers
        consumers = misc.SynchronizedList(self, field_name="inputs", initial=[])
        self.outputs = consumers
        self.dtype = dtype
        self.shape = misc.default_value(shape, None)
        self.type = type

    def to_constant(
        self,
        values: np.ndarray,
        export_dtype: np.dtype | onnx.TensorProto.DataType = None,
    ):
        """Turn this Variable into a Constant in place, using the given values and optional export data type."""
        # Remove the Variable-only attributes before the base class re-classes the instance.
        for stale_attr in ("dtype", "shape"):
            delattr(self, stale_attr)

        return super().to_constant(values, export_dtype=export_dtype)

    def copy(self):
        """
        Makes a shallow copy of this tensor, omitting input and output information.

        Note: Generally, you should only ever make a copy of a Graph.
        """
        return Variable(name=self.name, dtype=self.dtype, shape=self.shape, type=self.type)

    def __eq__(self, other):
        """Check equality: name, producer/consumer names, dtype, shape and type must all match."""
        if not isinstance(other, Variable):
            return False

        def names_agree(mine, theirs):
            # Same length and pairwise-identical names.
            return len(mine) == len(theirs) and all(lhs.name == rhs.name for lhs, rhs in zip(mine, theirs))

        same_name = self.name == other.name
        same_inputs = names_agree(self.inputs, other.inputs)
        same_outputs = names_agree(self.outputs, other.outputs)
        same_dtype = self.dtype == other.dtype
        same_shape = self.shape == other.shape
        same_type = self.type == other.type

        return same_name and same_inputs and same_outputs and same_dtype and same_shape and same_type

    def replace_all_uses_with(self, var: Variable):
        """
        Replace every use of this variable with ``var``.

        Args:
            var (Variable): The variable that takes this variable's place in the outputs of its
                    producer nodes and in the inputs of its consumer nodes.
        """
        # Iterate over snapshots: reassigning a slot may modify self.inputs / self.outputs.
        producers = list(self.inputs)
        for producer in producers:
            for slot, produced in enumerate(producer.outputs):
                if produced == self:
                    producer.outputs[slot] = var

        consumers = list(self.outputs)
        for consumer in consumers:
            for slot, consumed in enumerate(consumer.inputs):
                if consumed == self:
                    consumer.inputs[slot] = var
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class LazyValues:
    """A special object that represents constant tensor values that should be lazily loaded."""

    def __init__(self, tensor):
        """
        Store the ONNX tensor and cache its shape, dtype and size in bytes as reported by the importer helpers.

        Args:
            tensor (onnx.TensorProto, onnx.SparseTensorProto): The ONNX tensor that this instance should lazily load.
        """
        from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import (
            get_itemsize,
            get_onnx_tensor_dtype,
            get_onnx_tensor_shape,
        )

        self.tensor = tensor
        tensor_shape = get_onnx_tensor_shape(tensor)
        self.shape = tensor_shape
        tensor_dtype = get_onnx_tensor_dtype(tensor)
        self.dtype = tensor_dtype
        # Total size = number of elements * size of one element.
        element_count = misc.volume(tensor_shape)
        self.nbytes = element_count * get_itemsize(tensor_dtype)

    def load(self):
        """
        Load a numpy array from the underlying tensor values.

        A warning is logged first when the tensor's dtype has no NumPy equivalent.

        Returns:
            np.array: A numpy array containing the values of the tensor.
        """
        import onnx
        import onnx.numpy_helper

        from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import (
            get_dtype_name,
            get_numpy_type,
        )

        numpy_type = get_numpy_type(self.dtype)
        if numpy_type is None:
            message = (
                f"Datatype: {get_dtype_name(self.dtype)} could not be converted to a NumPy type.\n"
                f"Accessing the values of this constant tensor ({self.tensor.name}) will cause them to be casted to a supported data type. "
                f"This means that the weights will have a different type than the original model when they are exported again!\n"
                f"If this is not what you intended, please avoid accessing the values of this constant tensor."
            )
            G_LOGGER.warning(message)

        loaded = onnx.numpy_helper.to_array(self.tensor)
        return np.array(loaded)

    def __str__(self):
        """Describe this LazyValues object by its shape and dtype."""
        shape, dtype = self.shape, self.dtype
        return f"LazyValues (shape={shape}, dtype={dtype})"

    def __repr__(self):  # Hack to make logging output pretty.
        """Use the same text as ``__str__`` so that logging output stays readable."""
        return str(self)

    def __eq__(self, other):
        """Check equality by comparing every field of the two underlying tensors except ``name``."""
        if not isinstance(other, LazyValues):
            return False

        compared_fields = (field.name for field in self.tensor.DESCRIPTOR.fields if field.name != "name")
        return all(
            not (getattr(self.tensor, field_name) != getattr(other.tensor, field_name))
            for field_name in compared_fields
        )
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class SparseValues(LazyValues):
    """A special object that represents constant tensor values that is sparse."""

    def load(self):
        """
        Load a dense numpy array from the sparse structure.

        Positions not listed in the sparse indices are zero. The dense array uses the same
        dtype as the stored values, so the element type is preserved.

        Only INT64 indices are supported, in either the 1-D ``[NNZ]`` (linearized) layout or
        the 2-D ``[NNZ, rank]`` layout; anything else is reported through ``G_LOGGER.critical``.

        Returns:
            np.array: A numpy array containing the values of the tensor.
        """
        import onnx
        import onnx.numpy_helper

        supported_index_type = [onnx.TensorProto.INT64]
        if self.tensor.indices.data_type not in supported_index_type:
            G_LOGGER.critical(
                f"Unsupported index data type {self.tensor.indices.data_type} in {self.tensor.values.name}"
            )

        if self.tensor.values.data_type == onnx.TensorProto.FLOAT16:
            # The int32_data entries are reinterpreted as float16 bit patterns.
            values_data = np.asarray(self.tensor.values.int32_data, dtype=np.uint16).view(np.float16)
        else:
            field_name = onnx.helper.tensor_dtype_to_field(self.tensor.values.data_type)
            values = getattr(self.tensor.values, field_name)
            dtype = onnx.helper.tensor_dtype_to_np_dtype(self.tensor.values.data_type)
            values_data = np.asarray(values, dtype)
        indices_data = self.tensor.indices.int64_data

        dense_shape = tuple(self.tensor.dims)
        index_rank = len(self.tensor.indices.dims)

        if index_rank == 1:
            # np.zeros defaults to float64, so pass the stored values' dtype explicitly.
            # int() is needed because np.prod of an empty shape is a float.
            values = np.zeros(int(np.prod(dense_shape)), dtype=values_data.dtype)
            # [NNZ] layout, in which case the i-th value must be the linearized-index of the i-th value.
            values[np.asarray(indices_data, dtype=np.int64)] = values_data
            values = values.reshape(dense_shape)
        elif index_rank == 2:
            # [NNZ, rank] with the [i,j]-th value corresponding to the j-th index of the i-th value
            values = np.zeros(dense_shape, dtype=values_data.dtype)
            indices_data = np.asarray(indices_data).reshape(tuple(self.tensor.indices.dims))

            for value_data, index_data in zip(values_data, indices_data):
                values[tuple(index_data)] = value_data
        else:
            # NOTE(review): G_LOGGER.critical presumably raises, making the return below unreachable here; confirm.
            G_LOGGER.critical(f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}")

        return values

    def __str__(self):
        """Return a string representation of the SparseValues object with its shape and data type."""
        return f"SparseValues (shape={self.shape}, dtype={self.dtype})"
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class Constant(Tensor):
    """A tensor whose value is known."""

    def __init__(
        self,
        name: str,
        values: np.ndarray | LazyValues,
        data_location: int | None = None,
        export_dtype: np.dtype | onnx.TensorProto.DataType = None,
    ):
        """
        Represents a Tensor whose value is known.

        Args:
            name (str): The name of the tensor.
            values (numpy.ndarray): The values in this tensor, in the form of a NumPy array.
                    A LazyValues (or SparseValues) instance is accepted as well.

            data_location (int):
                    An enum value indicating the location where the tensor data is stored.
                    Generally, this will come from onnx.TensorProto.DataLocation.


            export_dtype (Union[np.dtype, onnx.TensorProto.DataType]):
                    The data type of the tensor when exported to onnx. If not specified, then
                    the data type of values will be used.
        """
        self.name = name
        self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=[])
        self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=[])
        if not isinstance(values, (np.ndarray, LazyValues, SparseValues)):
            G_LOGGER.critical(
                "Provided `values` argument is not a NumPy array, a LazyValues instance or a "
                "SparseValues instance. Please provide a NumPy array or LazyValues instance "
                f"to construct a Constant. Note: Provided `values` parameter was: {values}"
            )
        self._values = values
        self.data_location = data_location
        self._export_dtype = export_dtype

    def to_variable(self, dtype: np.dtype = None, shape: Sequence[int | str] | None = None):
        """
        Convert this Constant in place into a Variable.

        Args:
            dtype (numpy.dtype): The data type of the resulting variable. When omitted, this
                    constant's export data type is used.
            shape (Sequence[Union[int, str]]): The shape of the resulting variable. Defaults to an empty list.

        Returns:
            self
        """
        if shape is None:
            shape = []

        # Resolve the dtype BEFORE deleting the private attributes: the export_dtype
        # property reads both _export_dtype and _values.
        var_dtype = dtype if dtype is not None else self.export_dtype

        del self._export_dtype
        del self._values

        return super().to_variable(var_dtype, shape)

    def copy(self):
        """
        Makes a shallow copy of this tensor, omitting input and output information.

        Note: Generally, you should only ever make a copy of a Graph.
        """
        return Constant(self.name, self._values, export_dtype=self.export_dtype)

    @property
    def values(self):
        """Return the values of the tensor, loading them if they are accessed for the first time."""
        if isinstance(self._values, LazyValues):
            # Replace the lazy wrapper with the loaded array so it is loaded only once.
            self._values = self._values.load()
        return self._values

    @values.setter
    def values(self, values: np.ndarray | LazyValues):
        """Replace the values of the tensor."""
        self._values = values

    @property
    def shape(self):
        """Return the shape of the tensor values."""
        return self._values.shape

    @property
    def dtype(self):
        """Return the data type (dtype) of the tensor values."""
        return self._values.dtype

    @property
    def export_dtype(self):
        """Return the export data type if one was specified, otherwise the data type of the values."""
        return self._export_dtype if self._export_dtype is not None else self.dtype

    @export_dtype.setter
    def export_dtype(self, export_dtype):
        """Set the export data type of the tensor values (None means: use the data type of the values)."""
        self._export_dtype = export_dtype

    def __repr__(self):  # Hack to make logging output pretty.
        """Return a string representation of the object, including its values, for improved logging readability."""
        ret = self.__str__()
        ret += f"\n{self._values}"
        return ret

    def __eq__(self, other):
        """
        Perform a check to see if two constants are equal.

        Constants are equal when their values have the same shape, dtype and content.
        Two lazily-loaded constants are compared without loading their values.
        """
        if not isinstance(other, Constant):
            return False

        if self._values.shape != other._values.shape:
            return False

        if self._values.dtype != other._values.dtype:
            return False

        return (
            self._values == other._values
            if isinstance(self._values, LazyValues) and isinstance(other._values, LazyValues)
            else np.array_equal(self.values, other.values)
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger.logger import G_LOGGER, LogMode
|