onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,669 @@
+ import functools
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
+ import numpy as np
+ import onnx
+ import torch
+ from ..helpers.torch_helper import to_tensor, to_numpy
+ from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
+ from .report_results_comparison import ReportResultComparison
+ from . import torch_ops
+
+
+ @functools.lru_cache
+ def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRunKernel]]:
+     """
+     Retrieves all the available kernel classes :class:`TorchOnnxEvaluator`
+     can use. The full list is the following.
+
+     .. runpython::
+         :showcode:
+
+         from onnx_diagnostic.reference.torch_evaluator import get_kernels
+
+         for k, v in sorted(get_kernels().items()):
+             domain, name, version = k
+             f = f"{name}({version})" if domain == "" else f"{name}[{domain}]({version})"
+             add = " " * max(25 - len(f), 0)
+             dd = " -- device dependent" if v.device_dependent() else ""
+             print(f"{f}{add} -- {v.__name__}{dd}")
+     """
+     res = {}
+     for _k, v in torch_ops.__dict__.items():
+         if isinstance(v, type) and issubclass(v, torch_ops.OpRunKernel) and "_" in v.__name__:
+             name, version = v.__name__.split("_")
+             domain = getattr(v, "domain", "")
+             res[domain, name, int(version)] = v
+     return res
+
+
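+ # Kernel classes exported by ``torch_ops`` follow a ``<OpType>_<Version>``
+ # naming convention which ``get_kernels`` parses back into registry keys.
+ # As an illustrative sketch (the class name is hypothetical):
+ #
+ #     class Neg_13(torch_ops.OpRunKernel):
+ #         ...
+ #
+ # would be registered under the key ``("", "Neg", 13)``.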
+ class TorchOnnxEvaluator:
+     """
+     Torch evaluator for onnx models. The evaluator does not store the
+     original proto it evaluates, to avoid keeping an extra copy of the
+     model in memory.
+
+     :param proto: a proto
+     :param providers: where to run the model
+     :param opsets: needed if proto is a graph
+     :param local_functions: known local functions
+     :param verbose: verbosity level
+     :param custom_kernels: dictionary of kernels the user can define to override
+         a specific implementation: ``("", "LayerNormalization"): CustomKernel``
+
+     The class holds the following attributes:
+
+     * `providers`: requested execution providers
+     * `default_device`: default torch device
+     * `constants`: all initializers or constants
+     * `kernels`: one kernel per node
+     * `runtime_info`: produced by :func:`first_used_last_used
+       <onnx_diagnostic.torch_onnx.runtime_info.first_used_last_used>`
+     * `last_used`: contains the list of intermediate results
+       to remove after every node execution;
+       this prevents memory from growing too much
+     * `functions`: local functions
+
+     The class is not multithreaded. `runtime_info` gets updated
+     by the class. The list of available kernels is returned by function
+     :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import onnx
+         import onnx.helper as oh
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.reference import TorchOnnxEvaluator
+
+         TFLOAT = onnx.TensorProto.FLOAT
+
+         proto = oh.make_model(
+             oh.make_graph(
+                 [
+                     oh.make_node("Sigmoid", ["Y"], ["sy"]),
+                     oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
+                     oh.make_node("Mul", ["X", "ysy"], ["final"]),
+                 ],
+                 "-nd-",
+                 [
+                     oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
+                     oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+                 ],
+                 [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
+             ),
+             opset_imports=[oh.make_opsetid("", 18)],
+             ir_version=9,
+         )
+
+         sess = TorchOnnxEvaluator(proto)
+         feeds = dict(X=torch.rand((1, 4, 5)), Y=torch.rand((2, 4, 5)))
+         result = sess.run(None, feeds)
+         print(string_type(result, with_shape=True, with_min_max=True))
+
+     With ``verbose=1``, the class prints out every kernel it runs and
+     every result it deletes along the way. It shows when a result is not
+     needed anymore; in that case, it is deleted to free the memory it takes.
+
+     .. runpython::
+         :showcode:
+
+         import onnx
+         import onnx.helper as oh
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.reference import TorchOnnxEvaluator
+
+         TFLOAT = onnx.TensorProto.FLOAT
+
+         proto = oh.make_model(
+             oh.make_graph(
+                 [
+                     oh.make_node("Sigmoid", ["Y"], ["sy"]),
+                     oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
+                     oh.make_node("Mul", ["X", "ysy"], ["final"]),
+                 ],
+                 "-nd-",
+                 [
+                     oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
+                     oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+                 ],
+                 [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
+             ),
+             opset_imports=[oh.make_opsetid("", 18)],
+             ir_version=9,
+         )
+
+         sess = TorchOnnxEvaluator(proto, verbose=1)
+         feeds = dict(X=torch.rand((1, 4, 5)), Y=torch.rand((2, 4, 5)))
+         result = sess.run(None, feeds)
+         print(string_type(result, with_shape=True, with_min_max=True))
+
+     The runtime can also execute the ONNX model on CUDA.
+     It follows the same logic as :class:`onnxruntime.InferenceSession`:
+     ``providers=["CUDAExecutionProvider"]``.
+     In that case, it is better to move the inputs to CUDA as well. The class
+     tries to move every weight to CUDA but keeps any tensor
+     identified as a shape on CPU. Some bugs may remain, as torch
+     raises an exception when devices are expected to be the same.
+     The runtime was validated with model :epkg:`arnir0/Tiny-LLM`.
+     The next example shows how to replace a kernel with a different
+     one based on :epkg:`onnxruntime`.
+
+     .. runpython::
+         :showcode:
+
+         import numpy as np
+         import onnx
+         import onnx.helper as oh
+         import onnxruntime
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
+         from onnx_diagnostic.reference import TorchOnnxEvaluator
+         from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
+
+         TFLOAT16 = onnx.TensorProto.FLOAT16
+
+         class LayerNormalizationOrt(OpRunKernel):
+             "LayerNormalization based on onnxruntime"
+
+             def __init__(self, node: onnx.NodeProto, version=None, verbose=0):
+                 super().__init__(node, version, verbose=verbose)
+                 self.axis = self.get_attribute_int(node, "axis", -1)
+                 self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+                 self.stash_type = onnx_dtype_to_torch_dtype(
+                     self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
+                 )
+                 self.compute_std = len(node.output) > 1
+                 assert not self.compute_std, "The kernel only computes the first output."
+                 layer_model = oh.make_model(
+                     oh.make_graph(
+                         [
+                             oh.make_node(
+                                 "LayerNormalization",
+                                 ["X", "W", "B"],
+                                 ["Z"],
+                                 axis=-1,
+                                 epsilon=9.999999974752427e-7,
+                             )
+                         ],
+                         "dummy",
+                         [
+                             oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                             oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                             oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                         ],
+                         [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+                     ),
+                     ir_version=9,
+                     opset_imports=[oh.make_opsetid("", 17)],
+                 )
+                 self.ort_sess = onnxruntime.InferenceSession(
+                     layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                 )
+
+             def run(self, x, scale, bias=None):
+                 print(f"-- running {self.__class__.__name__}")
+                 feeds = dict(X=x, W=scale)
+                 if bias is not None:
+                     feeds["B"] = bias
+                 feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
+                 got = self.ort_sess.run(None, feeds)[0]
+                 return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
+
+         # This kernel is tested on this model.
+         model = oh.make_model(
+             oh.make_graph(
+                 [
+                     oh.make_node(
+                         "LayerNormalization",
+                         ["X", "W", "B"],
+                         ["ln"],
+                         axis=-1,
+                         epsilon=9.999999974752427e-7,
+                     ),
+                     oh.make_node("Add", ["ln", "W"], ["Z"]),
+                 ],
+                 "dummy",
+                 [
+                     oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
+                     oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
+                     oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
+                 ],
+                 [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
+             ),
+             ir_version=9,
+             opset_imports=[oh.make_opsetid("", 17)],
+         )
+
+         torch_sess = TorchOnnxEvaluator(
+             model,
+             custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
+             verbose=1,
+         )
+         feeds = dict(
+             zip(
+                 torch_sess.input_names,
+                 [
+                     torch.rand(3, 4, 5, dtype=torch.float16),
+                     torch.abs(torch.rand(5, dtype=torch.float16)),
+                     torch.rand(5, dtype=torch.float16),
+                 ],
+             )
+         )
+         res = torch_sess.run(None, feeds)
+         print(string_type(res, with_shape=True, with_min_max=True))
+     """
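+
+     # ``get_inputs`` / ``get_outputs`` below return ``IO`` objects so that
+     # code written against :class:`onnxruntime.InferenceSession` keeps
+     # working. A minimal sketch (assuming ``sess`` is a TorchOnnxEvaluator):
+     #
+     #     for i in sess.get_inputs():
+     #         print(i.name, i.type, i.shape)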
+
+     class IO:
+         "Input or output description (name, type, shape), same API as onnxruntime."
+
+         def __init__(self, name: str, type: int, shape: Tuple[Union[str, int], ...]):
+             self.name = name
+             self.type = type
+             self.shape = shape
+
+     @classmethod
+     def _on_cuda(cls, providers) -> int:
+         "Returns the CUDA device id requested by ``providers``, -1 if none."
+         if not providers:
+             return -1
+         for p in providers:
+             if p == "CUDAExecutionProvider":
+                 return 0
+             if isinstance(p, tuple) and p[0] == "CUDAExecutionProvider":
+                 return p[1]["device_id"]
+         return -1
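+
+     # ``providers`` follows the onnxruntime convention. What ``_on_cuda``
+     # returns for typical values (read directly from the code above):
+     #
+     #     _on_cuda(["CPUExecutionProvider"])                        -> -1
+     #     _on_cuda(["CUDAExecutionProvider"])                       -> 0
+     #     _on_cuda([("CUDAExecutionProvider", {"device_id": 1})])   -> 1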
+
281
+ def __init__(
282
+ self,
283
+ proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
284
+ providers: Tuple[str, ...] = ("CPUExecutionProvider",),
285
+ opsets: Optional[Dict[str, int]] = None,
286
+ local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
287
+ verbose: int = 0,
288
+ custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRunKernel]]] = None,
289
+ ):
290
+ self.providers = providers
291
+ self.constants: Dict[str, torch.Tensor] = {}
292
+ self.kernels: List[Optional[torch_ops.OpRunKernel]] = []
293
+ self.functions = local_functions.copy() if local_functions else {}
294
+ self.CPU = torch.tensor([0]).to("cpu").device
295
+ self.verbose = verbose
296
+ self.custom_kernels = custom_kernels or {}
297
+ dev = self._on_cuda(providers)
298
+ if dev < 0:
299
+ self.default_device = self.CPU
300
+ self.CUDA = None
301
+ else:
302
+ self.CUDA = torch.tensor([0]).to(f"cuda:{dev}").device
303
+ self.default_device = self.CUDA
304
+
305
+ if isinstance(proto, str):
306
+ proto = onnx.load(proto)
307
+ if isinstance(proto, onnx.ModelProto):
308
+ assert opsets is None, "proto is a model, opsets must be None in that case"
309
+ assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
310
+ self.opsets = {d.domain: d.version for d in proto.opset_import}
311
+ for f in proto.functions:
312
+ self.functions[f.domain, f.name] = self.__class__(
313
+ f,
314
+ providers=providers,
315
+ local_functions=self.functions,
316
+ verbose=self.verbose,
317
+ )
318
+ self._build_initializers(proto.graph.initializer)
319
+ self._build_initializers(proto.graph.node)
320
+ self._build_kernels(proto.graph.node)
321
+ self.input_names = [i.name for i in proto.graph.input]
322
+ self.output_names = [i.name for i in proto.graph.output]
323
+ self._io_input_names = [
324
+ self.IO(
325
+ name=i.name,
326
+ type=i.type.tensor_type.elem_type,
327
+ shape=tuple(
328
+ d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
329
+ ),
330
+ )
331
+ for i in proto.graph.input
332
+ ]
333
+ self._io_output_names = [
334
+ self.IO(
335
+ name=i.name,
336
+ type=i.type.tensor_type.elem_type,
337
+ shape=tuple(
338
+ d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
339
+ ),
340
+ )
341
+ for i in proto.graph.output
342
+ ]
343
+ elif isinstance(proto, onnx.GraphProto):
344
+ assert opsets, "opsets must be specified if proto is a graph"
345
+ assert not proto.sparse_initializer, "sparse_initializer not support yet"
346
+ self.opsets = opsets
347
+ self._build_initializers(proto.initializer)
348
+ self._build_initializers(proto.node)
349
+ self._build_kernels(proto.node)
350
+ self.input_names = [i.name for i in proto.input]
351
+ self.output_names = [i.name for i in proto.output]
352
+ elif isinstance(proto, onnx.FunctionProto):
353
+ assert opsets is None, "proto is a model, opsets must be None in that case"
354
+ self.opsets = {d.domain: d.version for d in proto.opset_import}
355
+ self._build_initializers(proto.node)
356
+ self._build_kernels(proto.node)
357
+ self.input_names = list(proto.input)
358
+ self.output_names = list(proto.output)
359
+ else:
360
+ raise TypeError(f"Unexpected type {type(proto)} for proto")
361
+
362
+ self.runtime_info = first_used_last_used(proto, constant_as_initializer=True)
363
+ self.last_used: List[List[str]] = [[] for _ in self.kernels]
364
+ for name, info in self.runtime_info.items():
365
+ assert isinstance(info.last_used, int) or info.is_input, (
366
+ f"Missing field last_used in {info!r}, last_used={info.last_used!r}, "
367
+ f"This may mean the node is unused and it should be removed."
368
+ )
369
+ if info.last_used is None:
370
+ # Not used.
371
+ self.last_used[0].append(name)
372
+ elif not info.is_output and not info.is_initializer:
373
+ self.last_used[info.last_used].append(name)
374
+
375
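+     # ``opsets`` is only required when ``proto`` is a bare
+     # :class:`onnx.GraphProto`, which carries no opset imports of its own.
+     # A minimal sketch (assuming ``model`` is an onnx.ModelProto):
+     #
+     #     sess = TorchOnnxEvaluator(model.graph, opsets={"": 18})
+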
+     def get_inputs(self):
+         "Same API as onnxruntime."
+         assert hasattr(self, "_io_input_names"), "Missing attribute '_io_input_names'."
+         return self._io_input_names
+
+     def get_outputs(self):
+         "Same API as onnxruntime."
+         assert hasattr(self, "_io_output_names"), "Missing attribute '_io_output_names'."
+         return self._io_output_names
+
+     @property
+     def on_cuda(self) -> bool:
+         "Tells if the default device is CUDA."
+         return self.default_device == self.CUDA
+
+     def _build_initializers(self, inits: Sequence[Union[onnx.NodeProto, onnx.TensorProto]]):
+         for init in inits:
+             if isinstance(init, onnx.TensorProto):
+                 self.constants[init.name] = to_tensor(init).to(self.default_device)
+             elif (
+                 isinstance(init, onnx.NodeProto)
+                 and init.op_type == "Constant"
+                 and init.domain == ""
+             ):
+                 value = None
+                 for att in init.attribute:
+                     if att.name == "value":
+                         value = to_tensor(att.t).to(self.default_device)
+                     elif att.name == "value_floats":
+                         value = torch.tensor(list(att.floats), dtype=torch.float32).to(
+                             self.default_device
+                         )
+                 assert value is not None, f"No attribute value in node {init}"
+                 self.constants[init.output[0]] = value
+
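+     # A ``Constant`` node is folded into ``self.constants`` here;
+     # ``_build_kernels`` then stores ``None`` in its kernel slot so the node
+     # costs nothing at execution time.
+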
+     def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
+         kernels = get_kernels()
+         self.kernels.clear()
+         for node in nodes:
+             kernel_kwargs = dict(verbose=max(0, self.verbose - 1))
+             opset = self.opsets[node.domain]
+             key = node.domain, node.op_type, opset
+             if key[:2] in self.custom_kernels:
+                 cls = self.custom_kernels[key[:2]]
+                 ags = [self.default_device] if cls.device_dependent() else []
+                 kws = dict(parent=self) if cls.has_subgraphs() else {}
+                 kws.update(kernel_kwargs)  # type: ignore[arg-type]
+                 kernel2 = cls(node, opset, *ags, **kws)  # type: ignore[arg-type]
+                 self.kernels.append(kernel2)
+                 continue
+
+             if (node.domain, node.op_type) in self.functions:
+                 kernel = torch_ops.OpRunFunction(
+                     self.functions[node.domain, node.op_type],
+                     node,
+                     self.opsets[node.domain],
+                     **kernel_kwargs,
+                 )
+                 self.kernels.append(kernel)
+                 continue
+
+             if node.op_type == "Constant" and node.domain == "":
+                 # Treated as a constant.
+                 self.kernels.append(None)
+                 continue
+
+             while key not in kernels and opset > 0:
+                 opset -= 1
+                 key = node.domain, node.op_type, opset
+             assert key in kernels, (
+                 f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}, "
+                 f"local functions={sorted(self.functions)}"
+             )
+             cls = kernels[key]
+             ags = [self.default_device] if cls.device_dependent() else []
+             kws = dict(parent=self) if cls.has_subgraphs() else {}
+             kws.update(kernel_kwargs)  # type: ignore[arg-type]
+             kernel2 = cls(node, opset, *ags, **kws)  # type: ignore[arg-type]
+             self.kernels.append(kernel2)
+
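+     # Kernel lookup walks the version down until it finds a match: with
+     # ``opsets={"": 18}``, a node whose most recent kernel class is (for
+     # example) ``Xxx_13`` resolves from key ("", "Xxx", 18) down to
+     # ("", "Xxx", 13). The class name ``Xxx_13`` is hypothetical.
+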
+     def run(
+         self,
+         outputs: Optional[List[str]],
+         feeds: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray]],
+         report_cmp: Optional[ReportResultComparison] = None,
+     ) -> Union[List[Optional[torch.Tensor]], List[Optional[np.ndarray]]]:
+         """
+         Runs the ONNX model.
+
+         :param outputs: names of the requested outputs, all model outputs if None
+         :param feeds: inputs
+         :param report_cmp: used as a reference, every intermediate result
+             is compared to every existing one; if not None, it is an
+             instance of
+             :class:`onnx_diagnostic.reference.ReportResultComparison`
+         :return: output tensors
+         """
+         use_numpy = any(isinstance(t, np.ndarray) for t in feeds.values())
+         if use_numpy:
+             feeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
+         if outputs is None:
+             outputs = self.output_names
+
+         # sets constants
+         for k, v in self.constants.items():
+             r = self.runtime_info[k]
+             if not r.has_value:
+                 r.set_value(
+                     torch_ops.OpRunTensor(
+                         v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
+                         is_constant=True,
+                         may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
+                     )
+                 )
+                 if self.verbose:
+                     print(f"+C {r.name}: {r.string_type()}")
+
+         # inputs
+         for k, v in feeds.items():
+             r = self.runtime_info[k]
+             r.set_value(
+                 torch_ops.OpRunTensor(
+                     v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
+                     is_constant=False,
+                     may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
+                 )
+             )
+             if self.verbose:
+                 print(f"+I {r.name}: {r.string_type()}")
+
+         # node execution
+         for it, kernel in enumerate(self.kernels):
+             if kernel is not None:
+                 if self.verbose:
+                     print(
+                         f"{kernel.__class__.__name__}"
+                         f"({', '.join(kernel.input)}) -> "
+                         f"{', '.join(kernel.output)}"
+                     )
+                 # kernel execution
+                 inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
+                 if kernel.has_subgraphs():
+                     res = kernel.run(*inputs, context=self.runtime_info)  # type: ignore[call-arg]
+                 else:
+                     res = kernel.run(*inputs)
+                 if isinstance(res, tuple):
+                     # outputs
+                     assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
+                         f"Unexpected output type {[type(o) for o in res]} "
+                         f"for kernel {type(kernel)}."
+                     )
+                     for name, t in zip(kernel.output, res):
+                         self.runtime_info[name].set_value(t)
+                     if self.verbose:
+                         for name in kernel.output:
+                             print(f"+R {name}: {self.runtime_info[name].string_type()}")
+                 else:
+                     assert isinstance(
+                         res, torch_ops.OpRunValue
+                     ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
+                     self.runtime_info[kernel.output[0]].set_value(res)
+                     if self.verbose:
+                         print(
+                             f"+R {kernel.output[0]}: "
+                             f"{self.runtime_info[kernel.output[0]].string_type()}"
+                         )
+                 if report_cmp:
+                     reported = report_cmp.report(
+                         dict(
+                             zip(
+                                 kernel.output,
+                                 (
+                                     tuple((r.tensor if r else None) for r in res)  # type: ignore[attr-defined]
+                                     if isinstance(res, tuple)
+                                     else ((res.tensor if res else None),)  # type: ignore[attr-defined]
+                                 ),
+                             )
+                         )
+                     )
+                     if self.verbose > 1:
+                         print(f" -- report {len(reported)} comparisons")
+
+             # free intermediate results
+             for name in self.last_used[it]:
+                 self.runtime_info[name].clean_value()
+                 if self.verbose:
+                     print(f"- clean {name}")
+
+         assert all(
+             self.runtime_info[o].value is not None for o in outputs
+         ), "Not implemented yet when one output is None."
+         fres = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[union-attr]
+         if self.verbose:
+             print(f"++ outputs {', '.join(outputs)}")
+
+         # clean previous execution
+         for k in feeds:
+             self.runtime_info[k].clean_value()
+             if self.verbose:
+                 print(f"- clean {k}")
+         for o in outputs:
+             self.runtime_info[o].clean_value()
+             if self.verbose:
+                 print(f"- clean {o}")
+
+         if use_numpy:
+             return [None if a is None else to_numpy(a) for a in fres]
+         return fres
+
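+     # ``run`` above also accepts numpy feeds and then returns numpy outputs.
+     # A minimal sketch (the input name "X" is assumed):
+     #
+     #     feeds = {"X": np.random.rand(1, 4, 5).astype(np.float32)}
+     #     got = sess.run(None, feeds)  # -> list of numpy arrays
+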
+     def run_with_values(
+         self,
+         *args: Optional[torch_ops.OpRunTensor],
+         context: Optional[Dict[str, RuntimeValue]] = None,
+     ) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
+         """
+         Runs the ONNX model. The signature differs from :meth:`run`.
+         This method is called by every kernel holding a subgraph.
+         The local variables are stored in `context`.
+
+         :param args: inputs
+         :param context: local context for the execution of subgraphs
+         :return: output OpRunTensor
+         """
+         assert all(
+             isinstance(a, torch_ops.OpRunValue) for a in args
+         ), f"Unexpected type in args: {[type(a) for a in args]}"
+         outputs = self.output_names
+         context = context or {}
+
+         # sets constants
+         for k, v in self.constants.items():
+             r = self.runtime_info[k]
+             if not r.has_value:
+                 r.set_value(
+                     torch_ops.OpRunTensor(
+                         v.to(self.CUDA) if r.is_shape is False and self.on_cuda else v,
+                         is_constant=True,
+                         may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
+                     )
+                 )
+
+         # inputs
+         for k, v in zip(self.input_names, args):
+             r = self.runtime_info[k]
+             r.set_value(
+                 torch_ops.OpRunTensor(None) if v is None else v.__class__(v.tensor_or_sequence)
+             )
+
+         # node execution
+         for it, kernel in enumerate(self.kernels):
+             if kernel is not None:
+                 # kernel execution
+                 inputs = [
+                     (
+                         (
+                             self.runtime_info[i].value
+                             if i in self.runtime_info
+                             else context[i].value
+                         )
+                         if i
+                         else None
+                     )
+                     for i in kernel.input
+                 ]
+                 res = kernel.run(*inputs)
+                 if isinstance(res, tuple):
+                     # outputs
+                     assert all(isinstance(o, torch_ops.OpRunTensor) for o in res), (
+                         f"Unexpected output type {[type(o) for o in res]} "
+                         f"for kernel {type(kernel)}."
+                     )
+                     for name, t in zip(kernel.output, res):
+                         self.runtime_info[name].set_value(t)
+                 else:
+                     assert isinstance(
+                         res, torch_ops.OpRunValue
+                     ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
+                     self.runtime_info[kernel.output[0]].set_value(res)
+
+             # free intermediate results
+             for name in self.last_used[it]:
+                 self.runtime_info[name].clean_value()
+
+         assert all(
+             self.runtime_info[o].value is not None for o in outputs
+         ), "Not implemented yet when one output is None."
+         res2 = [self.runtime_info[o].value.copy() for o in outputs]  # type: ignore[assignment, union-attr]
+
+         # clean previous execution
+         for k in self.input_names:
+             self.runtime_info[k].clean_value()
+         for o in self.output_names:
+             self.runtime_info[o].clean_value()
+
+         return res2[0] if len(res2) == 1 else tuple(res2)  # type: ignore[index, return-value, arg-type]