onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/utils.py ADDED
@@ -0,0 +1,794 @@
+ from __future__ import annotations
+
+ import importlib.util
+ import logging
+ import os
+ import sys
+ from collections import defaultdict
+ from pathlib import Path
+
+ import numpy as np
+ import onnx
+ from onnx import checker, helper
+
+ import onnxslim.third_party.onnx_graphsurgeon as gs
+ from onnxslim.misc.tabulate import SEPARATING_LINE, tabulate
+ from onnxslim.third_party.onnx_graphsurgeon.logger.logger import G_LOGGER
+
+ logger = logging.getLogger("onnxslim")
+
+
+ import ml_dtypes
+
+ try:
+     from onnx._mapping import TensorDtypeMap
+ except ImportError:
+     from onnx.mapping import TensorDtypeMap
+
+ TENSOR_TYPE_MAP = {}
+
+ candidates = [
+     ("BFLOAT16", "bfloat16", "UINT16"),
+     ("FLOAT8E4M3FN", "float8_e4m3fn", "UINT8"),
+     ("FLOAT8E4M3FNUZ", "float8_e4m3fnuz", "UINT8"),
+     ("FLOAT8E5M2", "float8_e5m2", "UINT8"),
+     ("FLOAT8E5M2FNUZ", "float8_e5m2fnuz", "UINT8"),
+     ("UINT4", "uint4", "INT32"),
+     ("INT4", "int4", "INT32"),
+     ("FLOAT4E2M1", "float4_e2m1fn", "UINT8"),
+ ]
+
+ for onnx_name, ml_name, storage_name in candidates:
+     if hasattr(onnx.TensorProto, onnx_name) and hasattr(ml_dtypes, ml_name):
+         TENSOR_TYPE_MAP[int(getattr(onnx.TensorProto, onnx_name))] = TensorDtypeMap(
+             np.dtype(getattr(ml_dtypes, ml_name)),
+             int(getattr(onnx.TensorProto, storage_name)),
+             f"TensorProto.{onnx_name}",
+         )
+
+
+ def init_logging(verbose=False):
+     """Configure the logging settings for the application based on the verbosity level."""
+     logger = logging.getLogger("onnxslim")
+
+     if verbose:  # DEBUG
+         logger.setLevel(logging.DEBUG)
+         G_LOGGER.severity = logging.DEBUG
+     else:  # ERROR
+         logger.setLevel(logging.ERROR)
+         G_LOGGER.severity = logging.ERROR
+
+     if not logger.handlers:
+         handler = logging.StreamHandler(sys.stderr)
+         formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
+
+     G_LOGGER.colors = False
+
+     if is_onnxruntime_available():
+         import onnxruntime as ort
+
+         ort.set_default_logger_severity(3)
+
+     return logger
+
+
+ def format_bytes(size: int | tuple[int, ...]) -> str:
+     """Convert byte sizes into human-readable format with appropriate units (B, KB, MB, GB)."""
+     if isinstance(size, int):
+         size = (size,)
+     elif isinstance(size, np.integer):
+         size = (int(size),)
+
+     units = ["B", "KB", "MB", "GB"]
+     formatted_sizes = []
+
+     for size_in_bytes in size:
+         unit_index = 0
+         while size_in_bytes >= 1024 and unit_index < len(units) - 1:
+             size_in_bytes /= 1024
+             unit_index += 1
+
+         formatted_size = f"{size_in_bytes:.2f} {units[unit_index]}"
+         formatted_sizes.append(formatted_size)
+
+     if len(formatted_sizes) == 1:
+         return formatted_sizes[0]
+     else:
+         return f"{formatted_sizes[0]} ({formatted_sizes[1]})"
+
+
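A quick worked illustration of the helper above (values follow directly from the divide-by-1024 loop; not taken from the package's tests). The two-element form matches what save() produces, new size first and original size second:

    format_bytes(512)                     # "512.00 B"
    format_bytes(1536)                    # "1.50 KB"
    format_bytes((3 * 1024**2, 1024**2))  # "3.00 MB (1.00 MB)"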
+ def onnx_dtype_to_numpy(onnx_dtype: int) -> np.dtype | str:
+     """Map an ONNX dtype to its corresponding NumPy dtype, or return "UNDEFINED" if no mapping exists."""
+     tensor_dtype = TENSOR_TYPE_MAP.get(onnx_dtype)
+
+     if tensor_dtype:
+         return tensor_dtype.np_dtype
+
+     if onnx_dtype in onnx.helper.get_all_tensor_dtypes():
+         return np.dtype(helper.tensor_dtype_to_np_dtype(onnx_dtype))
+
+     return "UNDEFINED"
+
+
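A minimal sketch of the three lookup outcomes, assuming ml_dtypes is installed (the module imports it unconditionally):

    onnx_dtype_to_numpy(onnx.TensorProto.FLOAT)     # dtype('float32'), via onnx.helper
    onnx_dtype_to_numpy(onnx.TensorProto.BFLOAT16)  # bfloat16 dtype, via TENSOR_TYPE_MAP / ml_dtypes
    onnx_dtype_to_numpy(9999)                       # "UNDEFINED" for unknown type codes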
+ def gen_onnxruntime_input_data(
+     model: onnx.ModelProto, model_check_inputs: list[str] | None = None
+ ) -> dict[str, np.ndarray]:
+     """Generate random input data for an ONNX model, honoring any user-specified input shapes or .npy data."""
+     input_info = {}
+     for input_tensor in model.graph.input:
+         name = input_tensor.name
+         shape = []
+         for dim in input_tensor.type.tensor_type.shape.dim:
+             if dim.HasField("dim_param"):
+                 shape.append(dim.dim_param)
+             elif dim.HasField("dim_value"):
+                 shape.append(dim.dim_value)
+             else:
+                 shape.append(None)
+         dtype = onnx_dtype_to_numpy(input_tensor.type.tensor_type.elem_type)
+
+         input_info[name] = {"shape": shape, "dtype": dtype}
+
+     if model_check_inputs:
+         for model_check_input in model_check_inputs:
+             key, value = model_check_input.rsplit(":", 1)
+             if value.endswith(".npy"):
+                 if key not in input_info:
+                     raise Exception(
+                         f"model_check_input name:{key} not found in model, available keys: {' '.join(input_info.keys())}"
+                     )
+                 data = np.load(value)
+                 input_info[key] = {"data": data}
+             else:
+                 values_list = [int(val) for val in value.split(",")]
+                 if key in input_info:
+                     input_info[key]["shape"] = values_list
+                 else:
+                     raise Exception(
+                         f"model_check_input name:{key} not found in model, available keys: {' '.join(input_info.keys())}"
+                     )
+
+     input_data_dict = {}
+     for name, info in input_info.items():
+         if "data" in info:
+             input_data_dict[name] = info["data"]
+         else:
+             # Substitute 1 for symbolic (str), unknown (None), or -1 dimensions.
+             shapes = [shape if (shape is not None and shape != -1 and not isinstance(shape, str)) else 1 for shape in info["shape"]]
+             shapes = shapes or [1]
+             dtype = info["dtype"]
+
+             if dtype in {np.int32, np.int64}:
+                 random_data = np.random.randint(10, size=shapes).astype(dtype)
+             else:
+                 random_data = np.random.rand(*shapes).astype(dtype)
+             input_data_dict[name] = random_data
+
+     return input_data_dict
+
+
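The model_check_inputs entries use a name:value syntax (split by the rsplit above); a hedged illustration with a hypothetical input named "images":

    # Pin the symbolic dims of "images" to a concrete shape:
    gen_onnxruntime_input_data(model, ["images:1,3,224,224"])
    # Or feed recorded data from a .npy file instead of random values:
    gen_onnxruntime_input_data(model, ["images:sample.npy"])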
+ def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> tuple[dict[str, np.ndarray], onnx.ModelProto]:
+     """Perform inference using ONNX Runtime and return the outputs along with the (possibly reloaded) model."""
+     import os
+     import tempfile
+
+     import onnx
+     import onnxruntime as rt
+
+     if model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
+         # Models at or above the 2 GB protobuf limit must be serialized with external data.
+         tmp_dir = tempfile.TemporaryDirectory()
+         tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
+         location = f"{os.path.basename(tmp_path)}.data"
+         if os.path.exists(location):
+             os.remove(location)
+         onnx.save(
+             model,
+             tmp_path,
+             save_as_external_data=True,
+             all_tensors_to_one_file=True,
+             location=location,
+         )
+         onnx_model = tmp_path
+     else:
+         onnx_model = model.SerializeToString()
+
+     sess = rt.InferenceSession(onnx_model, providers=["CPUExecutionProvider"])
+     onnx_output = sess.run(None, input_data)
+
+     output_names = [output.name for output in sess.get_outputs()]
+     onnx_output = dict(zip(output_names, onnx_output))
+
+     if isinstance(onnx_model, str):
+         model = onnx.load(onnx_model)
+
+     return onnx_output, model
+
+
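Paired with the generator above, this makes a self-contained smoke test; a sketch assuming onnxruntime is installed and a local "model.onnx" exists:

    model = onnx.load("model.onnx")
    inputs = gen_onnxruntime_input_data(model)
    outputs, model = onnxruntime_inference(model, inputs)  # model may be reloaded for >2 GB graphs
    for name, value in outputs.items():
        print(name, value.shape, value.dtype)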
+ def format_model_info(model_info_list: ModelInfo | list[ModelInfo], elapsed_time: float | None = None):
+     """Assemble the rows of the model-comparison table consumed by tabulate."""
+     assert model_info_list, "model_info_list must contain at least one model info"
+     from colorama import Fore, init
+
+     init()
+     if not isinstance(model_info_list, (list, tuple)):
+         model_info_list = [model_info_list]
+
+     final_op_info = []
+     final_op_info.extend(
+         (
+             ["Model Name"] + [item.tag for item in model_info_list],
+             [SEPARATING_LINE] * (len(model_info_list) + 1),
+             ["Model Info"]
+             + ["Op Set: " + item.op_set + " / IR Version: " + item.ir_version for item in model_info_list],
+             [SEPARATING_LINE] * (len(model_info_list) + 1),
+         )
+     )
+
+     def get_io_info(model_info_list, tag=None):
+         if tag == "OUT":
+             ios = [op_type for model_info in model_info_list for op_type in model_info.output_info]
+         else:
+             ios = [op_type for model_info in model_info_list for op_type in model_info.input_info]
+         ios = list(dict.fromkeys([io.name for io in ios]))
+         io_info = []
+         for io in ios:
+             input_info_list = [f"{tag}: {io}"]
+             for model_info in model_info_list:
+                 if tag == "OUT":
+                     io_tensor = model_info.output_maps.get(io, None)
+                 else:
+                     io_tensor = model_info.input_maps.get(io, None)
+                 inputs_shape = (io_tensor.dtype, io_tensor.shape) if io_tensor else ""
+                 if isinstance(inputs_shape, (list, tuple)):
+                     inputs_shape = ": ".join([str(i) for i in inputs_shape])
+                 input_info_list.append(inputs_shape)
+             io_info.append(input_info_list)
+
+         return io_info
+
+     final_op_info.extend(get_io_info(model_info_list, "IN"))
+     final_op_info.extend(get_io_info(model_info_list, "OUT"))
+
+     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
+
+     all_ops = {op_type for model_info in model_info_list for op_type in model_info.op_type_counts}
+     sorted_ops = sorted(all_ops)
+     for op in sorted_ops:
+         op_info_list = [op]
+         float_number = model_info_list[0].op_type_counts.get(op, 0)
+         op_info_list.append(float_number)
+         for model_info in model_info_list[1:]:
+             slimmed_number = model_info.op_type_counts.get(op, 0)
+             if float_number > slimmed_number:
+                 # Highlight op counts that decreased after slimming.
+                 slimmed_number = Fore.GREEN + str(slimmed_number) + Fore.WHITE
+             op_info_list.append(slimmed_number)
+
+         final_op_info.append(op_info_list)
+     final_op_info.extend(
+         (
+             [SEPARATING_LINE] * (len(model_info_list) + 1),
+             ["Model Size"] + [format_bytes(model_info.model_size) for model_info in model_info_list],
+         )
+     )
+     if elapsed_time:
+         final_op_info.extend(
+             (
+                 [SEPARATING_LINE] * (len(model_info_list) + 1),
+                 ["Elapsed Time", f"{elapsed_time:.2f} s"],
+             )
+         )
+
+     return final_op_info
+
+
+ def print_model_info_as_table(model_info_list: ModelInfo | list[ModelInfo], elapsed_time: float | None = None):
+     """Print the model information as a formatted table for the given list of model details."""
+     if not isinstance(model_info_list, (list, tuple)):
+         model_info_list = [model_info_list]
+
+     final_op_info = format_model_info(model_info_list, elapsed_time)
+     lines = tabulate(
+         final_op_info,
+         headers=[],
+         tablefmt="pretty",
+         maxcolwidths=[None] + [40] * len(model_info_list),
+     ).split("\n")
+     if elapsed_time:
+         # Merge the last two cells of the elapsed-time row so the value is
+         # centered across the remaining column width.
+         time_row = lines[-2].split("|")
+         time_row[-3] = (
+             time_row[-2][: len(time_row[-2]) // 2 + 1] + time_row[-3] + time_row[-2][len(time_row[-2]) // 2 :]
+         )
+         time_row.pop(-2)
+         lines[-2] = "|".join(time_row)
+     # Replace SEPARATING_LINE sentinel rows with the table's border line.
+     output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines])
+
+     print(output)
+
+
+ def dump_model_info_to_disk(model_info: ModelInfo):
+     """Write model information to a CSV file named after the model's tag."""
+     import csv
+
+     csv_file_path = f"{model_info.tag}_model_info.csv"
+     with open(csv_file_path, "a", newline="") as csvfile:  # Use 'a' for append mode
+         fieldnames = ["NodeName", "OpType", "OutputDtype", "OutputShape"]
+         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+         # If the file is empty, write the header
+         if csvfile.tell() == 0:
+             writer.writeheader()
+
+         # Write the data
+         for node_name, info in model_info.op_info.items():
+             op_type, output_info_list = info.op, info.outputs
+             if len(output_info_list) >= 1:
+                 # Write the first row with actual NodeName and OpType
+                 row_data_first = {
+                     "NodeName": node_name,
+                     "OpType": op_type,
+                     "OutputDtype": output_info_list[0].dtype,  # First entry in the list
+                     "OutputShape": output_info_list[0].shape,  # First entry in the list
+                 }
+                 writer.writerow(row_data_first)
+
+                 # Write subsequent rows with empty strings for NodeName and OpType.
+                 # Each entry is a TensorInfo, so read its dtype/shape attributes.
+                 for output_info in output_info_list[1:]:
+                     row_data_empty = {
+                         "NodeName": "",
+                         "OpType": "",
+                         "OutputDtype": output_info.dtype,
+                         "OutputShape": output_info.shape,
+                     }
+                     writer.writerow(row_data_empty)
+     print(f"Model info written to {csv_file_path}")
+
+
+ def get_opset(model: onnx.ModelProto) -> int | None:
+     """Return the default-domain ONNX opset version for a given model, or None if unavailable."""
+     try:
+         for importer in model.opset_import:
+             if importer.domain in {"", "ai.onnx"}:
+                 return importer.version
+
+         return None
+     except Exception:
+         return None
+
+
+ def get_ir_version(model: onnx.ModelProto) -> int | None:
+     """Return the ONNX IR version for a given model, or None if unavailable."""
+     try:
+         return model.ir_version
+     except Exception:
+         return None
+
+
+ class TensorInfo:
+     def __init__(self, tensor):
+         self.dtype: np.dtype = np.float32
+         self.shape: tuple[str | int, ...] | None = None
+
+         self._extract_info(tensor)
+
+     def _extract_info(self, tensor):
+         """Extract the data type and shape of an ONNX tensor."""
+         self.dtype = onnx_dtype_to_numpy(tensor.type.tensor_type.elem_type)
+         shape = None
+         if tensor.type.tensor_type.HasField("shape"):
+             shape = []
+             for dim in tensor.type.tensor_type.shape.dim:
+                 if dim.HasField("dim_param"):
+                     shape.append(dim.dim_param)
+                 elif dim.HasField("dim_value"):
+                     shape.append(dim.dim_value)
+                 else:
+                     shape.append("?")
+
+         self.shape = tuple(shape) if shape is not None else None
+         self.name = tensor.name
+
+
+ class OperatorInfo:
+     def __init__(self, operator, outputs=None):
+         self.name: str | None = None
+         self.op: str | None = None
+
+         self._extract_info(operator)
+         self.outputs = outputs
+
+     def _extract_info(self, operator):
+         self.name: str = operator.name
+         self.op: str = operator.op_type
+
+
+ class ModelInfo:
+     def __init__(self, model: str | onnx.ModelProto, tag: str = "OnnxSlim"):
+         if isinstance(model, str):
+             tag = Path(model).name
+             model = onnx.load(model)
+
+         self.tag: str = tag
+         self.model_size: int = -1
+         self.op_set: str | None = None
+         self.ir_version: str | None = None
+         self.op_type_counts: dict[str, int] = defaultdict(int)
+         self.op_info: dict[str, OperatorInfo] = {}
+         self.input_info: list[TensorInfo] = []
+         self.output_info: list[TensorInfo] = []
+
+         self._summarize_model(model)
+
+     def _summarize_model(self, model):
+         self.op_set = str(get_opset(model))
+         self.ir_version = str(get_ir_version(model))
+         self.model_size = get_initializer_size(model)
+
+         for input in model.graph.input:
+             self.input_info.append(TensorInfo(input))
+
+         for output in model.graph.output:
+             self.output_info.append(TensorInfo(output))
+
+         value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info}
+
+         def get_graph_node_info(graph: onnx.GraphProto) -> None:
+             ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}
+             for node in graph.node:
+                 op_type = node.op_type
+                 self.op_type_counts[op_type] += 1
+                 output_tensor_info = []
+                 for output in node.output:
+                     if output in value_info_dict:
+                         tensor = value_info_dict[output]
+                         tensor_info = TensorInfo(tensor)
+                         output_tensor_info.append(tensor_info)
+
+                 self.op_info[node.name] = OperatorInfo(node, output_tensor_info)
+
+                 # Recurse into subgraph attributes (e.g. If/Loop/Scan bodies).
+                 for attr in node.attribute:
+                     if attr.type in ATTR_TYPE_MAPPING:
+                         attr_str = ATTR_TYPE_MAPPING[attr.type]
+                         if attr_str == "GRAPH":
+                             get_graph_node_info(attr.g)
+
+         get_graph_node_info(model.graph)
+
+     @property
+     def input_maps(self):
+         self.input_dict = {input_info.name: input_info for input_info in self.input_info}
+
+         return self.input_dict
+
+     @property
+     def output_maps(self):
+         self.output_dict = {output_info.name: output_info for output_info in self.output_info}
+
+         return self.output_dict
+
+
+ def summarize_model(model: str | onnx.ModelProto, tag="OnnxModel") -> ModelInfo:
+     """Generate a summary of the ONNX model, including model size, operator counts, and tensor shapes."""
+     logger.debug("Start summarizing model.")
+     model_info = ModelInfo(model, tag)
+     logger.debug("Finish summarizing model.")
+     return model_info
+
+
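A sketch of the reporting pipeline these classes feed, assuming a local "model.onnx":

    info = summarize_model("model.onnx")  # tag defaults to the file name
    print_model_info_as_table(info)       # op counts, IO shapes, and size as a table
    dump_model_info_to_disk(info)         # per-node CSV: "model.onnx_model_info.csv"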
+ def model_save_as_external_data(model: onnx.ModelProto, model_path: str):
+     """Save an ONNX model with tensor data as an external file."""
+     location = f"{os.path.basename(model_path)}.data"
+     if os.path.exists(location):
+         os.remove(location)
+     onnx.save(
+         model,
+         model_path,
+         save_as_external_data=True,
+         all_tensors_to_one_file=True,
+         location=location,
+     )
+
+
+ def check_onnx(model: onnx.ModelProto, model_check_inputs=None):
+     """Validate an ONNX model by generating input data and performing inference to check outputs."""
+     input_data_dict = gen_onnxruntime_input_data(model, model_check_inputs)
+     raw_onnx_output, model = onnxruntime_inference(model, input_data_dict)
+
+     return input_data_dict, raw_onnx_output, model
+
+
+ def check_point(model: onnx.ModelProto):
+     """Import an ONNX model into a Graphsurgeon graph representation to serve as a checkpoint."""
+     return gs.import_onnx(model)
+
+
+ def save(
+     model: onnx.ModelProto,
+     model_path: str,
+     model_check: bool = False,
+     save_as_external_data: bool = False,
+     model_info: ModelInfo | None = None,
+ ):
+     """Save an ONNX model to a specified path, with optional model checking for validity."""
+     if model_check:
+         try:
+             checker.check_model(model)
+         except ValueError:
+             logger.warning("Model too large and cannot be checked.")
+
+     if model_path:  # models larger than 2 GB can be saved, but parsers such as trtexec won't read them
+         if get_initializer_size(model) <= checker.MAXIMUM_PROTOBUF and not save_as_external_data:
+             onnx.save(model, model_path)
+         else:
+             location = f"{os.path.basename(model_path)}.data"
+             if os.path.exists(location):
+                 os.remove(location)
+             onnx.save(
+                 model,
+                 model_path,
+                 save_as_external_data=True,
+                 all_tensors_to_one_file=True,
+                 location=location,
+             )
+             logger.debug("Model too large and saved as external data automatically.")
+
+     if model_info:
+         model_size = model.ByteSize()
+         model_info.model_size = [model_size, model_info.model_size]
+
+
+ def check_result(raw_onnx_output, slimmed_onnx_output):
+     """Verify the consistency of outputs between the raw and slimmed ONNX models, printing warnings if
+     discrepancies are detected.
+     """
+     if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()):
+         print("Model output mismatch after slimming.")
+         print(f"Raw model output keys: {raw_onnx_output.keys()}")
+         print(f"Slimmed model output keys: {slimmed_onnx_output.keys()}")
+         print("Please check the model carefully.")
+         return False
+     else:
+         for key in raw_onnx_output.keys():
+             if not np.allclose(
+                 raw_onnx_output[key],
+                 slimmed_onnx_output[key],
+                 rtol=1e-03,
+                 atol=1e-04,
+                 equal_nan=True,
+             ):
+                 print(f"\033[31mModel output {key} mismatch after slimming.\033[0m")
+                 print("\033[31mPlease check the model carefully.\033[0m")
+                 return False
+
+     return True
+
+
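A minimal before/after equivalence check assembled from these helpers; `slimmed_model` here is a placeholder for whatever optimized graph you produced:

    inputs, raw_out, model = check_onnx(model)                  # reference outputs on random inputs
    slim_out, _ = onnxruntime_inference(slimmed_model, inputs)  # same inputs, slimmed model
    assert check_result(raw_out, slim_out)                      # rtol=1e-03 / atol=1e-04, NaN-tolerant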
+ def get_numpy_type(onnx_type):
+     if not isinstance(onnx_type, int):
+         # Already a NumPy type
+         return onnx_type
+
+     numpy_unsupported_types = [
+         onnx.TensorProto.BFLOAT16,
+         onnx.TensorProto.FLOAT8E4M3FN,
+         onnx.TensorProto.FLOAT8E4M3FNUZ,
+         onnx.TensorProto.FLOAT8E5M2,
+         onnx.TensorProto.FLOAT8E5M2FNUZ,
+     ]
+
+     # TENSOR_TYPE_TO_NP_TYPE maps types NumPy does not support natively to unrelated
+     # substitute types, which would misreport them, so treat them as a special case.
+     if onnx_type not in numpy_unsupported_types and onnx_type in onnx.helper.get_all_tensor_dtypes():
+         return onnx.helper.tensor_dtype_to_np_dtype(onnx_type)
+     return None
+
+
+ def get_itemsize(dtype):
+     np_dtype = get_numpy_type(dtype)
+     if np_dtype is not None:
+         return np.dtype(np_dtype).itemsize
+
+     if dtype == onnx.TensorProto.BFLOAT16:
+         return 2
+
+     if dtype in [
+         onnx.TensorProto.FLOAT8E4M3FN,
+         onnx.TensorProto.FLOAT8E4M3FNUZ,
+         onnx.TensorProto.FLOAT8E5M2,
+         onnx.TensorProto.FLOAT8E5M2FNUZ,
+     ]:
+         return 1
+
+     raise ValueError(f"Unsupported TensorProto dtype: {dtype}")
+
+
+ def calculate_tensor_size(tensor):
+     """Calculate the size of an ONNX tensor in bytes based on its shape and element size."""
+     shape = tensor.dims
+     num_elements = np.prod(shape) if shape else 0
+     element_size = get_itemsize(tensor.data_type)
+     return num_elements * element_size
+
+
+ def get_initializer_size(model):
+     """Calculate the total initializer size across all (sub)graphs of an ONNX model."""
+     total_size = get_graph_initializer_size(model.graph)
+     return total_size
+
+
+ def get_graph_initializer_size(graph):
+     initializer_size = 0
+     for tensor in graph.initializer:
+         tensor_size = calculate_tensor_size(tensor)
+         initializer_size += tensor_size
+
+     for node in graph.node:
+         if node.op_type == "Constant":
+             for attr in node.attribute:
+                 if attr.name == "value" and attr.type == onnx.AttributeProto.TENSOR:
+                     initializer_size += calculate_tensor_size(attr.t)
+
+         elif node.op_type == "If":
+             # An If node carries two subgraph attributes (then/else branches).
+             initializer_size += get_graph_initializer_size(node.attribute[0].g)
+             initializer_size += get_graph_initializer_size(node.attribute[1].g)
+         elif node.op_type == "Loop":
+             initializer_size += get_graph_initializer_size(node.attribute[0].g)
+         elif node.op_type == "Scan":
+             initializer_size += get_graph_initializer_size(node.attribute[0].g)
+
+     return initializer_size
+
+
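For a concrete sense of the arithmetic, a hand-computed example using a hypothetical 1000x1000 float32 initializer:

    from onnx import numpy_helper
    w = numpy_helper.from_array(np.zeros((1000, 1000), dtype=np.float32), name="w")
    calculate_tensor_size(w)  # 1000 * 1000 elements * 4 bytes = 4000000
    format_bytes(4_000_000)   # "3.81 MB"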
+ def is_onnxruntime_available():
+     if importlib.util.find_spec("onnxruntime") is None:
+         logger = logging.getLogger("onnxslim")
+         logger.debug("onnxruntime is not available, please install it first for better optimization")
+         return False
+     try:
+         # Guard against a broken onnxruntime installation that fails on import.
+         import onnxruntime as ort
+
+         return hasattr(ort, "__version__")
+     except Exception:
+         logger = logging.getLogger("onnxslim")
+         logger.debug("onnxruntime is not available, please install it first for better optimization")
+         return False
+
+
+ def check_onnx_compatibility():
+     """Warn when the installed ONNX Runtime and ONNX versions are not known to be compatible."""
+     compatibility_dict = {
+         "1.20": "1.16",
+         "1.19": "1.16",
+         "1.18": "1.16",
+         "1.17": "1.15",
+         "1.16": "1.14.1",
+         "1.15": "1.14",
+         "1.14": "1.13",
+         "1.13": "1.12",
+         "1.12": "1.12",
+         "1.11": "1.11",
+         "1.10": "1.10",
+         "1.9": "1.10",
+         "1.8": "1.9",
+         "1.7": "1.8",
+         "1.6": "1.8",
+         "1.5": "1.7",
+         "1.4": "1.7",
+         "1.3": "1.7",
+         "1.2": "1.6",
+         "1.1": "1.6",
+         "1.0": "1.6",
+         "0.5": "1.5",
+         "0.4": "1.5",
+         "0.3": "1.4",
+         "0.2": "1.3",
+         "0.1": "1.3",
+     }
+     import onnx
+     import onnxruntime
+
+     onnx_version = onnx.__version__
+     # Reduce the version to "major.minor", dropping any local suffix (e.g. "+cpu").
+     ort_version = ".".join(onnxruntime.__version__.split("+")[0].split(".")[:2])
+     # Check compatibility
+     expected_onnx_version = compatibility_dict.get(ort_version)
+     if expected_onnx_version is None:
+         print(
+             f"Warning: ONNX Runtime version {ort_version} has no specified compatible ONNX version. Compatibility issues may occur."
+         )
+     elif expected_onnx_version == ".".join(onnx_version.split("+")[0].split(".")[:2]):
+         logger.info(
+             f"Installed ONNX Runtime version {ort_version} is compatible with installed ONNX version {onnx_version}."
+         )
+     else:
+         print(
+             f"Warning: Installed ONNX Runtime version {ort_version} is not compatible with installed ONNX version {onnx_version}. Expected ONNX version: {expected_onnx_version}."
+         )
+
+
+ def get_max_tensor(model, topk=5):
+     """Print the top-k largest constant tensors in the model, including those in subgraphs."""
+     graph = gs.import_onnx(model)
+
+     tensor_map = graph.tensors()
+     constant_tensors = [tensor for tensor in tensor_map.values() if isinstance(tensor, gs.Constant)]
+
+     sub_graphs = graph.subgraphs(recursive=True)
+     sub_graphs_constant_tensors = [
+         [tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)]
+         for sub_graph in sub_graphs
+     ]
+
+     constant_tensors.extend([tensor for tensors in sub_graphs_constant_tensors for tensor in tensors])
+
+     sizes = [tensor.values.size for tensor in constant_tensors]
+     sorted_indices = np.argsort(sizes)[::-1][:topk]
+
+     for i in sorted_indices:
+         tensor = constant_tensors[i]
+         print(
+             f"Tensor name: {tensor.name}, shape: {tensor.values.shape}, dtype: {tensor.values.dtype}, size: {tensor.values.size}"
+         )
+
+
+ # copied from https://onnx.ai/onnx/api/tools.html
+ def update_outputs_dims(
+     model,
+     output_dims,
+ ):
+     dim_param_set: set[str] = set()
+
+     def init_dim_param_set(dim_param_set, value_infos):
+         for info in value_infos:
+             shape = info.type.tensor_type.shape
+             for dim in shape.dim:
+                 if dim.HasField("dim_param"):
+                     dim_param_set.add(dim.dim_param)  # type: ignore
+
+     init_dim_param_set(dim_param_set, model.graph.output)  # type: ignore
+
+     def update_dim(tensor, dim, j, name) -> None:
+         dim_proto = tensor.type.tensor_type.shape.dim[j]
+
+         # If the model already fixes this axis to an int, keep it rather than
+         # replacing it with the requested symbol or value.
+         if dim_proto.HasField("dim_value"):
+             return
+
+         if isinstance(dim, int):
+             if dim >= 0:
+                 if dim_proto.HasField("dim_value") and dim_proto.dim_value != dim:
+                     raise ValueError(
+                         f"Unable to set dimension value to {dim} for axis {j} of {name}. Contradicts existing dimension value {dim_proto.dim_value}."
+                     )
+                 dim_proto.dim_value = dim
+             else:
+                 generated_dim_param = name + "_" + str(j)
+                 if generated_dim_param in dim_param_set:
+                     raise ValueError(
+                         f"Unable to generate unique dim_param for axis {j} of {name}. Please manually provide a dim_param value."
+                     )
+                 dim_proto.dim_param = generated_dim_param
+         elif isinstance(dim, str):
+             dim_proto.dim_param = dim
+         else:
+             raise ValueError(f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}")
+
+     for output in model.graph.output:
+         output_name = output.name
+         output_dim_arr = output_dims[output_name]
+         if output_dim_arr is None:
+             continue
+
+         if len(output.type.tensor_type.shape.dim) == 0:
+             for _ in range(len(output_dim_arr)):
+                 output.type.tensor_type.shape.dim.add()
+         for j, dim in enumerate(output_dim_arr):
+             update_dim(output, dim, j, output_name)
+
+     return model
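Finally, a hedged usage sketch for update_outputs_dims, following the convention of the onnx.tools API it was copied from (an int pins an axis, a string names it symbolically, a negative int generates a fresh dim_param). It assumes "logits" is the model's sole output, since every graph output is looked up in the dict:

    model = onnx.load("model.onnx")
    model = update_outputs_dims(model, {"logits": ["batch", 1000]})  # symbolic batch, fixed classes
    onnx.save(model, "model_dynamic.onnx")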