onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/torch_helper.py
@@ -0,0 +1,987 @@
+ import contextlib
+ import ctypes
+ import inspect
+ import os
+ import sys
+ import warnings
+ from collections.abc import Iterable
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+ import numpy as np
+ import onnx
+ from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
+ import torch
+ from .helper import string_type, size_type
+ from .cache_helper import (
+     make_dynamic_cache,
+     make_encoder_decoder_cache,
+     make_hybrid_cache,
+     make_sliding_window_cache,
+     make_mamba_cache,
+     make_static_cache,
+     CacheKeyValue,
+ )
+ from .mini_onnx_builder import create_onnx_model_from_input_tensors
+ from .onnx_helper import (
+     to_array_extended,
+     tensor_dtype_to_np_dtype,
+     _STORAGE_TYPE,
+     onnx_dtype_name,
+ )
+
+
+ def proto_from_tensor(
+     arr: "torch.Tensor",  # noqa: F821
+     name: Optional[str] = None,
+     verbose: int = 0,
+ ) -> onnx.TensorProto:
+     """
+     Converts a torch Tensor into a TensorProto.
+
+     :param arr: tensor
+     :param name: name of the initializer
+     :param verbose: display the type and shape
+     :return: a TensorProto
+     """
+     import torch
+
+     if not isinstance(arr, torch.Tensor):
+         raise TypeError(f"Unexpected type {type(arr)}.")
+     if arr.is_sparse:
+         raise NotImplementedError(
+             f"Sparse tensor is not supported yet but initializer {name!r} is."
+         )
+
+     # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
+     if arr.is_contiguous():
+         arr_cpu = arr.cpu()
+     else:
+         arr_cpu = arr.contiguous().cpu()
+
+     numel = torch.numel(arr_cpu)
+     element_size = arr_cpu.element_size()
+
+     if arr_cpu.dtype in {torch.bfloat16}:
+         np_arr = arr_cpu
+     elif arr_cpu.data_ptr() == arr.data_ptr():
+         copy = arr_cpu.clone().detach().requires_grad_(False)
+         assert (
+             arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
+         ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
+         np_arr = np.from_dlpack(copy)
+     else:
+         np_arr = np.from_dlpack(arr_cpu.detach())
+
+     tensor = onnx.TensorProto()
+     tensor.dims.extend(arr_cpu.shape)
+     if name:
+         tensor.name = name
+     itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype)
+     assert not hasattr(onnx.TensorProto, "INT4") or itype not in {
+         onnx.TensorProto.INT4,
+         onnx.TensorProto.UINT4,
+     }, f"Type {arr.dtype} is not supported yet for name={name!r}"
+     tensor.data_type = itype
+
+     if verbose > 1 and numel > 100:
+         print(f"[proto_from_tensor] {tensor.data_type}[{arr_cpu.shape}]")
+
+     if isinstance(np_arr, torch.Tensor):
+         byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
+         tensor.raw_data = bytes(byte_data)
+         if sys.byteorder == "big":
+             # ndarray.byteswap is a method, not a numpy-level function: swap a copy
+             # and write the bytes back.
+             np_dtype = _STORAGE_TYPE[tensor.data_type]  # type: ignore
+             tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+     else:
+         tensor.raw_data = np_arr.tobytes()
+         if sys.byteorder == "big":
+             np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+             tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+     return tensor
+
+
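A quick sketch of how this helper can be exercised (illustrative values, the import path follows the file list above):

    import torch
    from onnx_diagnostic.helpers.torch_helper import proto_from_tensor

    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    proto = proto_from_tensor(t, name="weight")
    # proto.data_type == onnx.TensorProto.FLOAT (1), proto.dims == [2, 3]
    print(proto.name, proto.dims, proto.data_type)
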
+ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+     """
+     Converts an onnx type into a torch dtype.
+
+     :param itype: onnx dtype
+     :return: torch dtype
+     """
+     if itype == onnx.TensorProto.FLOAT:
+         return torch.float32
+     if itype == onnx.TensorProto.FLOAT16:
+         return torch.float16
+     if itype == onnx.TensorProto.BFLOAT16:
+         return torch.bfloat16
+     if itype == onnx.TensorProto.DOUBLE:
+         return torch.float64
+     if itype == onnx.TensorProto.INT32:
+         return torch.int32
+     if itype == onnx.TensorProto.INT64:
+         return torch.int64
+     if itype == onnx.TensorProto.UINT32:
+         return torch.uint32
+     if itype == onnx.TensorProto.UINT64:
+         return torch.uint64
+     if itype == onnx.TensorProto.BOOL:
+         return torch.bool
+     if itype == onnx.TensorProto.INT16:
+         return torch.int16
+     if itype == onnx.TensorProto.UINT16:
+         return torch.uint16
+     if itype == onnx.TensorProto.INT8:
+         return torch.int8
+     if itype == onnx.TensorProto.UINT8:
+         return torch.uint8
+     if itype == onnx.TensorProto.COMPLEX64:
+         return torch.complex64
+     if itype == onnx.TensorProto.COMPLEX128:
+         return torch.complex128
+     raise NotImplementedError(
+         f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
+     )
+
+
+ def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+     """
+     Converts a torch dtype into an onnx element type.
+
+     :param to: torch dtype
+     :return: onnx type
+     """
+     import torch
+
+     if to == torch.float32:
+         return onnx.TensorProto.FLOAT
+     if to == torch.float16:
+         return onnx.TensorProto.FLOAT16
+     if to == torch.bfloat16:
+         return onnx.TensorProto.BFLOAT16
+     if to == torch.float64:
+         return onnx.TensorProto.DOUBLE
+     if to == torch.int64:
+         return onnx.TensorProto.INT64
+     if to == torch.int32:
+         return onnx.TensorProto.INT32
+     if to == torch.uint64:
+         return onnx.TensorProto.UINT64
+     if to == torch.uint32:
+         return onnx.TensorProto.UINT32
+     if to == torch.bool:
+         return onnx.TensorProto.BOOL
+     if to == torch.SymInt:
+         return onnx.TensorProto.INT64
+     if to == torch.int16:
+         return onnx.TensorProto.INT16
+     if to == torch.uint16:
+         return onnx.TensorProto.UINT16
+     if to == torch.int8:
+         return onnx.TensorProto.INT8
+     if to == torch.uint8:
+         return onnx.TensorProto.UINT8
+     if to == torch.SymFloat:
+         return onnx.TensorProto.FLOAT
+     if to == torch.complex64:
+         return onnx.TensorProto.COMPLEX64
+     if to == torch.complex128:
+         return onnx.TensorProto.COMPLEX128
+     raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
+
+
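A minimal round-trip sketch for the two converters above:

    import onnx
    import torch
    from onnx_diagnostic.helpers.torch_helper import (
        onnx_dtype_to_torch_dtype,
        torch_dtype_to_onnx_dtype,
    )

    itype = torch_dtype_to_onnx_dtype(torch.bfloat16)
    assert itype == onnx.TensorProto.BFLOAT16
    assert onnx_dtype_to_torch_dtype(itype) == torch.bfloat16
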
+ def _forward_(
+     *args,
+     _f=None,
+     _fprint=string_type,
+     _prefix="",
+     _context=None,
+     _storage=None,
+     _storage_limit=2**27,
+     _verbose=0,
+     **kwargs,
+ ):
+     assert _f is not None, "_f cannot be None"
+     assert _context is not None, "_context cannot be None"
+     indent = " " * (len(_prefix) - len(_prefix.lstrip()))
+     _prefix = _prefix.lstrip()
+     print(
+         f"{indent}+{_prefix} -- stolen forward for class {_context['class_name']} "
+         f"-- iteration {_context['iteration']}"
+     )
+     kws = dict(
+         with_shape=_context.get("with_shape", False),
+         with_min_max=_context.get("with_min_max", False),
+     )
+     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
+         # torch.compiler.is_exporting requires torch>=2.7
+         print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
+     if _storage is not None:
+         it = _context["iteration"]
+         key = (_prefix, it)
+         _storage[(*key, "I")] = (torch_deepcopy(args), torch_deepcopy(kwargs))
+     res = _f(*args, **kwargs)
+     if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
+         print(f"{indent} -> {_fprint(res, **kws)}")
+         print(f"{indent}-{_prefix}.")
+     if _storage is not None:
+         size = torch_tensor_size(res)
+         if size < _storage_limit:
+             if _verbose:
+                 print(
+                     f"-- stores key={key}, size {size // 2**10}Kb -- "
+                     f"{string_type(res, with_shape=True)}"
+                 )
+             _storage[(*key, "O")] = torch_deepcopy(res)
+         else:
+             if _verbose:
+                 print(
+                     f"-- skips key={key}, size {size // 2**10}Kb -- "
+                     f"{string_type(res, with_shape=True)}"
+                 )
+     _context["iteration"] += 1
+     return res
+
+
+ _steal_forward_status = [False]
+ _additional_stolen_objects = {}
+
+
+ def is_stealing() -> bool:
+     """Returns True while the :func:`steal_forward` context is active."""
+     return _steal_forward_status[0]
+
+
+ def steal_append(name: str, obj: Any):
+     """
+     When outside a forward method, it is still possible to add
+     a python object which contains tensors and to dump it after the execution
+     of the model.
+
+     .. code-block:: python
+
+         steal_append("quantize", [t1, t2])
+
+     The same code can be executed multiple times; in that case,
+     the name is extended with a number.
+     """
+     if is_stealing():
+         if name in _additional_stolen_objects:
+             i = 1
+             n = f"{name}_{i}"
+             while n in _additional_stolen_objects:
+                 i += 1
+                 n = f"{name}_{i}"
+             print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
+             _additional_stolen_objects[n] = obj
+         else:
+             print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
+             _additional_stolen_objects[name] = obj
+
+
+ @contextlib.contextmanager
+ def steal_forward(
+     model: Union[
+         Union[torch.nn.Module, Tuple[str, torch.nn.Module]],
+         List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
+     ],
+     fprint: Callable = string_type,
+     dump_file: Optional[str] = None,
+     dump_drop: Optional[Set[str]] = None,
+     submodules: bool = False,
+     verbose: int = 0,
+     storage_limit: int = 2**27,
+     save_as_external_data: bool = True,
+     **kwargs,
+ ):
+     """
+     The necessary modification to steal a forward method and print out inputs
+     and outputs using :func:`onnx_diagnostic.helpers.string_type`.
+     See example :ref:`l-plot-tiny-llm-export` or
+     :ref:`l-plot-intermediate-results`.
+
+     :param model: a model or a list of models to monitor,
+         every model can also be a tuple(name, model), the name is displayed as well
+     :param fprint: function used to print out (or dump), by default, it is
+         :func:`onnx_diagnostic.helpers.string_type`
+     :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
+         or any other function defined by ``fprint``
+     :param dump_file: dumps stolen inputs and outputs in an onnx model,
+         they can be restored with :func:`create_input_tensors_from_onnx_model
+         <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
+     :param dump_drop: to drop some inputs which are too big (only if dump_file is specified)
+     :param save_as_external_data: True by default, but it may be better to have everything
+         in a single file if possible
+     :param submodules: if True and model is a module, the list is extended with all the
+         submodules the module contains
+     :param verbose: verbosity
+     :param storage_limit: do not store objects bigger than this
+
+     The following example shows how to steal and dump all the inputs / outputs
+     for a module and its submodules, then restore them.
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.helpers.torch_helper import steal_forward
+         from onnx_diagnostic.helpers.mini_onnx_builder import (
+             create_input_tensors_from_onnx_model,
+         )
+
+         class SubModel(torch.nn.Module):
+             def forward(self, x):
+                 return x * x
+
+         class Model(torch.nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.s1 = SubModel()
+                 self.s2 = SubModel()
+
+             def forward(self, x, y):
+                 return self.s1(x) + self.s2(y)
+
+         inputs = torch.rand(2, 1), torch.rand(2, 1)
+         model = Model()
+         dump_file = "dump_steal_forward_submodules.onnx"
+         with steal_forward(model, submodules=True, dump_file=dump_file):
+             model(*inputs)
+
+         # Let's restore the stolen data.
+         restored = create_input_tensors_from_onnx_model(dump_file)
+         for k, v in sorted(restored.items()):
+             if isinstance(v, tuple):
+                 args, kwargs = v
+                 print("input", k, args, kwargs)
+             else:
+                 print("output", k, v)
+
+     Function :func:`steal_append` can be used to dump more tensors.
+     When inside the context, :func:`is_stealing` returns True, False otherwise.
+     """
+     assert not is_stealing(), "steal_forward was already called."
+     # We clear the cache.
+     _steal_forward_status[0] = True
+     _additional_stolen_objects.clear()
+     assert not submodules or isinstance(
+         model, torch.nn.Module
+     ), f"submodules can only be True if model is a module but it is {type(model)}."
+     context = dict(iteration=0, **kwargs)
+     if "with_shape" not in context and fprint == string_type:
+         context["with_shape"] = True
+     if not isinstance(model, list):
+         assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
+         if submodules:
+             models = []
+             for idx, m in model.named_modules():
+                 level = str(idx).split(".")
+                 ll = len(level)
+                 try:
+                     _, start_line = inspect.getsourcelines(m.forward)
+                 except OSError:
+                     # The code is not available.
+                     start_line = 0
+                 name = f"{idx}-{m.__class__.__name__}-{start_line}"
+                 models.append((f"{' ' * ll}{name}", m))
+             model = models
+         else:
+             model = [model]
+     keep_model_forward = {}
+     storage: Optional[Dict[Any, Any]] = {} if dump_file else None
+     for mt in model:
+         name, m = mt if isinstance(mt, tuple) else ("", mt)
+         keep_model_forward[id(m)] = (m, m.forward)
+         c = context.copy()
+         c["class_name"] = m.__class__.__name__
+         m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_(  # noqa: E501
+             *args,
+             _f=_f,
+             _fprint=_fp,
+             _context=_c,
+             _prefix=_p,
+             _storage=_s,
+             _verbose=_v,
+             _storage_limit=_sl,
+             **kws,
+         )
+     try:
+         yield
+     finally:
+         _steal_forward_status[0] = False
+         for f in keep_model_forward.values():
+             f[0].forward = f[1]
+         if dump_file:
+             # Let's add the cached tensors.
+             assert storage is not None, "storage cannot be None but mypy is confused here."
+             storage.update(_additional_stolen_objects)
+             # We clear the cache.
+             _additional_stolen_objects.clear()
+             if verbose:
+                 size = torch_tensor_size(storage)
+                 print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
+             if dump_drop:
+                 for k, v in storage.items():
+                     if k[-1] == "I":
+                         _args, kwargs = v
+                         ii = set(kwargs) & dump_drop
+                         if ii:
+                             for i in ii:
+                                 print("---", i)
+                                 del kwargs[i]
+             proto = create_onnx_model_from_input_tensors(storage)
+             if verbose:
+                 print("-- dumps stored objects")
+             location = f"{os.path.split(dump_file)[-1]}.data"
+             if os.path.exists(location):
+                 os.remove(location)
+             onnx.save(
+                 proto,
+                 dump_file,
+                 save_as_external_data=save_as_external_data,
+                 all_tensors_to_one_file=True,
+                 location=location,
+             )
+             if verbose:
+                 print("-- done dump stored objects")
+
+
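The docstring covers the submodules case; the list form with ``(name, model)`` pairs, which the signature also accepts, can be sketched like this (modules and names are illustrative):

    import torch
    from onnx_diagnostic.helpers.torch_helper import steal_forward

    head, tail = torch.nn.Linear(4, 2), torch.nn.Linear(2, 1)
    # Each entry may be a module or a (name, module) pair, the name shows up in the trace.
    with steal_forward([("head", head), ("tail", tail)]):
        tail(head(torch.rand(3, 4)))
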
+ @contextlib.contextmanager
+ def fake_torchdynamo_exporting():
+     """
+     Sets ``torch.compiler._is_exporting_flag`` to True to trigger
+     pieces of code only enabled during export.
+     """
+     memorize = torch.compiler._is_exporting_flag
+     torch.compiler._is_exporting_flag = True
+     try:
+         yield
+     finally:
+         torch.compiler._is_exporting_flag = memorize
+
+
+ def is_torchdynamo_exporting() -> bool:
+     """
+     Tells if :epkg:`torch` is exporting a model.
+     Relies on ``torch.compiler.is_exporting()``.
+     """
+     import torch
+
+     if not hasattr(torch.compiler, "is_exporting"):
+         # torch.compiler.is_exporting requires torch>=2.7
+         return False
+
+     try:
+         return torch.compiler.is_exporting()
+     except Exception:
+         try:
+             import torch._dynamo as dynamo
+
+             return dynamo.is_exporting()  # type: ignore
+         except Exception:
+             return False
+
+
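The two helpers combine naturally in tests, assuming a torch version where ``torch.compiler.is_exporting`` reads ``_is_exporting_flag``; a sketch:

    from onnx_diagnostic.helpers.torch_helper import (
        fake_torchdynamo_exporting,
        is_torchdynamo_exporting,
    )

    with fake_torchdynamo_exporting():
        # Code taking an export-only branch can be tested here.
        assert is_torchdynamo_exporting()
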
+ def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
+     """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
+     try:
+         return tensor.detach().cpu().numpy()
+     except TypeError:
+         # We try with ml_dtypes.
+         pass
+
+     import ml_dtypes
+
+     conv = {torch.bfloat16: ml_dtypes.bfloat16}
+     assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
+     return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
+
+
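A sketch of the bfloat16 fallback path (requires the ml_dtypes package):

    import torch
    from onnx_diagnostic.helpers.torch_helper import to_numpy

    # .numpy() raises TypeError for bfloat16, so the ml_dtypes branch is taken.
    arr = to_numpy(torch.tensor([1.5, 2.0], dtype=torch.bfloat16))
    print(arr.dtype)  # bfloat16
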
+ def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
+     """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
+     import torch
+
+     if isinstance(dynamic_shapes, torch.export.dynamic_shapes._Dim):
+         return dynamic_shapes
+     if isinstance(dynamic_shapes, str):
+         return torch.export.Dim.DYNAMIC
+     if not dynamic_shapes:
+         return dynamic_shapes
+     if isinstance(dynamic_shapes, (tuple, list)):
+         return type(dynamic_shapes)(replace_string_by_dynamic(i) for i in dynamic_shapes)
+     if isinstance(dynamic_shapes, dict):
+         return {k: replace_string_by_dynamic(v) for k, v in dynamic_shapes.items()}
+     raise AssertionError(f"Unexpected type {type(dynamic_shapes)} for dynamic_shapes")
+
+
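For instance, a nested specification written with strings (sketch):

    from onnx_diagnostic.helpers.torch_helper import replace_string_by_dynamic

    spec = {"input_ids": {0: "batch", 1: "seq"}, "mask": None}
    # Both string axes become torch.export.Dim.DYNAMIC, None is left untouched.
    print(replace_string_by_dynamic(spec))
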
+ def dummy_llm(
+     cls_name: Optional[str] = None,
+     dynamic_shapes: bool = False,
+ ) -> Union[
+     Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]],
+     Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], Any],
+ ]:
+     """
+     Creates a dummy LLM for test purposes.
+
+     :param cls_name: None for the whole model or the name of one of its pieces
+     :param dynamic_shapes: returns dynamic shapes as well
+
+     .. runpython::
+         :showcode:
+
+         from onnx_diagnostic.helpers.torch_helper import dummy_llm
+         print(dummy_llm())
+     """
+
+     class Embedding(torch.nn.Module):
+         def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
+             super().__init__()
+             self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
+             self.pe = torch.nn.Embedding(vocab_size, embedding_dim)
+
+         def forward(self, x):
+             word_emb = self.embedding(x)
+             word_pe = self.pe(x)
+             return word_emb + word_pe
+
+     class AttentionBlock(torch.nn.Module):
+
+         def __init__(self, embedding_dim: int = 16, context_size: int = 256):
+             super().__init__()
+             self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
+             self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
+             self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
+             # torch.nn.Buffer are not fully handled by symbolic tracing
+             # Buffer(...)[:Proxy()] is not working
+             self.mask = torch.nn.Parameter(
+                 torch.tril(
+                     input=torch.ones(size=[context_size, context_size], dtype=torch.float)
+                 )
+             )
+
+         def forward(self, x):
+             _B, T, C = x.shape
+
+             query = self.query(x)
+             key = self.key(x)
+             value = self.value(x)
+
+             qk = query @ key.transpose(-2, -1) * C**-0.5
+             attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
+             attention = torch.nn.functional.softmax(input=attention, dim=-1)
+
+             out = attention @ value
+             return out
+
+     class MultiAttentionBlock(torch.nn.Module):
+
+         def __init__(
+             self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
+         ):
+             super().__init__()
+             self.attention = torch.nn.ModuleList(
+                 modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
+             )
+             self.linear = torch.nn.Linear(
+                 in_features=embedding_dim * num_heads, out_features=embedding_dim
+             )
+
+         def forward(self, x):
+             out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
+             x = self.linear(out)
+             return x
+
+     class FeedForward(torch.nn.Module):
+
+         def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
+             super().__init__()
+             self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
+             self.relu = torch.nn.ReLU()
+             self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)
+
+         def forward(self, x):
+             x = self.linear_1(x)
+             x = self.relu(x)
+             x = self.linear_2(x)
+             return x
+
+     class DecoderLayer(torch.nn.Module):
+
+         def __init__(
+             self,
+             embedding_dim: int = 16,
+             num_heads: int = 2,
+             context_size: int = 256,
+             ff_dim: int = 128,
+         ):
+             super().__init__()
+             self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
+             self.feed_forward = FeedForward(embedding_dim, ff_dim)
+             self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
+             self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
+
+         def forward(self, x):
+             x_norm = self.norm_1(x)
+             attention = self.attention(x_norm)
+             attention = attention + x
+
+             attention_norm = self.norm_2(attention)
+             ff = self.feed_forward(attention_norm)
+             ff = ff + attention
+
+             return ff
+
+     class LLM(torch.nn.Module):
+
+         def __init__(
+             self,
+             vocab_size: int = 1024,
+             embedding_dim: int = 16,
+             num_heads: int = 2,
+             context_size: int = 256,
+             ff_dim: int = 128,
+         ):
+             super().__init__()
+             self.embedding = Embedding(vocab_size, embedding_dim)
+             self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)
+
+         def forward(self, input_ids):
+             x = self.embedding(input_ids)
+             y = self.decoder(x)
+             return y
+
+     if cls_name in (None, "LLM"):
+         dec: torch.nn.Module = LLM()
+         x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
+         dec(x)
+         if dynamic_shapes:
+             dyn = {
+                 "input_ids": {
+                     0: torch.export.Dim("batch", min=1, max=1024),
+                     1: torch.export.Dim("length", min=1, max=255),
+                 }
+             }
+             return dec, (x,), dyn
+         return dec, (x,)
+
+     if cls_name == "DecoderLayer":
+         LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))
+
+         dec = DecoderLayer()
+         x = Embedding()(
+             torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
+         )
+         dec(x)
+         if dynamic_shapes:
+             dyn = {
+                 "x": {
+                     0: torch.export.Dim("batch", min=1, max=1024),
+                     1: torch.export.Dim("length", min=1, max=255),
+                 }
+             }
+             return dec, (x,), dyn
+         return dec, (x,)
+
+     if cls_name == "MultiAttentionBlock":
+         dec = MultiAttentionBlock()
+         x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
+         dec(x)
+         if dynamic_shapes:
+             dyn = {
+                 "x": {
+                     0: torch.export.Dim("batch", min=1, max=1024),
+                     1: torch.export.Dim("length", min=1, max=255),
+                 }
+             }
+             return dec, (x,), dyn
+         return dec, (x,)
+
+     if cls_name == "AttentionBlock":
+         dec = AttentionBlock()
+         x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
+         dec(x)
+         if dynamic_shapes:
+             dyn = {
+                 "x": {
+                     0: torch.export.Dim("batch", min=1, max=1024),
+                     1: torch.export.Dim("length", min=1, max=255),
+                 }
+             }
+             return dec, (x,), dyn
+         return dec, (x,)
+
+     raise NotImplementedError(f"cls_name={cls_name}")
+
+
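A sketch of exporting one of the pieces with its dynamic shapes; whether the export goes through depends on the torch version:

    import torch
    from onnx_diagnostic.helpers.torch_helper import dummy_llm

    model, inputs, ds = dummy_llm("AttentionBlock", dynamic_shapes=True)
    ep = torch.export.export(model, inputs, dynamic_shapes=ds)
    print(ep.graph)
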
+ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
+     """Applies torch.to if applicable. Goes recursively."""
+     if isinstance(value, (torch.nn.Module, torch.Tensor)) and value.__class__.__name__ not in {
+         "DynamicCache",
+         "EncoderDecoderCache",
+     }:
+         if (
+             (
+                 isinstance(to_value, torch.dtype)
+                 or to_value in {"float16", "bfloat16", "float32", "float64"}
+             )
+             and hasattr(value, "dtype")
+             and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
+         ):
+             # int vectors should not be changed.
+             return value
+         return value.to(to_value)
+     if isinstance(value, list):
+         return [to_any(t, to_value) for t in value]
+     if isinstance(value, tuple):
+         return tuple(to_any(t, to_value) for t in value)
+     if isinstance(value, set):
+         return {to_any(t, to_value) for t in value}
+     if type(value) is dict:
+         return {k: to_any(t, to_value) for k, t in value.items()}
+     if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
+         make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
+         cc = CacheKeyValue(value)
+         return make[value.__class__.__name__](  # type: ignore[operator]
+             list(
+                 zip(
+                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
+                     [t.to(to_value) if t is not None else t for t in cc.value_cache],
+                 )
+             )
+         )
+     if value.__class__.__name__ == "StaticCache":
+         cc = CacheKeyValue(value)
+         return make_static_cache(
+             list(
+                 zip(
+                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
+                     [t.to(to_value) if t is not None else t for t in cc.value_cache],
+                 )
+             ),
+             max_cache_len=value.max_cache_len,
+         )
+     if value.__class__.__name__ == "EncoderDecoderCache":
+         return make_encoder_decoder_cache(
+             to_any(value.self_attention_cache, to_value),
+             to_any(value.cross_attention_cache, to_value),
+         )
+     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+         args, spec = torch.utils._pytree.tree_flatten(value)
+         new_args = to_any(args, to_value)
+         return torch.utils._pytree.tree_unflatten(new_args, spec)
+
+     if hasattr(value, "to"):
+         return value.to(to_value)
+
+     assert "Cache" not in value.__class__.__name__, (
+         f"Class {value.__class__.__name__!r} should be registered "
+         f"to be able to change the type in every tensor it contains."
+     )
+     assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
+     return value
+
+
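A short sketch of the recursive conversion; note that integer tensors are deliberately left untouched:

    import torch
    from onnx_diagnostic.helpers.torch_helper import to_any

    batch = {"x": torch.rand(2, 3), "ids": torch.tensor([1, 2])}
    half = to_any(batch, torch.float16)
    print(half["x"].dtype, half["ids"].dtype)  # torch.float16 torch.int64
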
+ def torch_deepcopy(value: Any) -> Any:
+     """
+     Makes a deep copy.
+
+     :param value: any value
+     :return: a deep copy
+     """
+     if value is None:
+         return None
+     if isinstance(value, (int, float, str)):
+         return value
+     if isinstance(value, tuple):
+         return tuple(torch_deepcopy(v) for v in value)
+     if isinstance(value, list):
+         return [torch_deepcopy(v) for v in value]
+     if isinstance(value, set):
+         return {torch_deepcopy(v) for v in value}
+     if isinstance(value, dict):
+         if type(value) is dict:
+             return {k: torch_deepcopy(v) for k, v in value.items()}
+         # for BaseModelOutput
+         return value.__class__(**{k: torch_deepcopy(v) for k, v in value.items()})
+     if isinstance(value, np.ndarray):
+         return value.copy()
+     if hasattr(value, "clone"):
+         return value.clone()
+     if value.__class__.__name__ == "DynamicCache":
+         from .cache_helper import CacheKeyValue
+
+         ca = CacheKeyValue(value)
+         return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
+     if value.__class__.__name__ == "StaticCache":
+         from .cache_helper import CacheKeyValue
+
+         ca = CacheKeyValue(value)
+         if len(ca.key_cache) == 0:
+             # Use of deepcopy.
+             import copy
+
+             return copy.deepcopy(value)
+         return make_static_cache(
+             torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))),
+             max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
+         )
+     if value.__class__.__name__ == "HybridCache":
+         from .cache_helper import CacheKeyValue
+
+         ca = CacheKeyValue(value)
+         return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
+     if value.__class__.__name__ == "SlidingWindowCache":
+         from .cache_helper import CacheKeyValue
+
+         ca = CacheKeyValue(value)
+         return make_sliding_window_cache(
+             torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))
+         )
+     if value.__class__.__name__ == "EncoderDecoderCache":
+         return make_encoder_decoder_cache(
+             torch_deepcopy(value.self_attention_cache),
+             torch_deepcopy(value.cross_attention_cache),
+         )
+     if value.__class__.__name__ == "MambaCache":
+         return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
+
+     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+         args, spec = torch.utils._pytree.tree_flatten(value)
+         new_args = torch_deepcopy(args)
+         return torch.utils._pytree.tree_unflatten(new_args, spec)
+
+     if value.__class__.__name__ == "Results":
+         import copy
+         import ultralytics
+
+         assert isinstance(
+             value, ultralytics.engine.results.Results
+         ), f"Unexpected type={type(value)}"
+         return copy.deepcopy(value)
+
+     if hasattr(value, "__nocopy__"):
+         return value
+
+     # We should have a code using serialization, deserialization assuming a model
+     # cannot be exported without them.
+     raise NotImplementedError(
+         f"torch_deepcopy not implemented for type {type(value)}, "
+         f"add attribute '__nocopy__' to return it as is."
+     )
+
+
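A sketch on a nested structure, showing the copy no longer aliases the original storage:

    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    data = {"w": torch.zeros(2), "meta": [1, "a", (torch.ones(1),)]}
    copy = torch_deepcopy(data)
    copy["w"] += 1
    print(data["w"], copy["w"])  # tensor([0., 0.]) tensor([1., 1.])
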
+ def torch_tensor_size(value: Any) -> Any:
+     """Returns the number of bytes stored in tensors."""
+     if value is None:
+         return 0
+     if isinstance(value, (int, float, str)):
+         return 0
+     if isinstance(value, (tuple, list, set)):
+         return sum(torch_tensor_size(v) for v in value)
+     if isinstance(value, dict):
+         return sum(torch_tensor_size(v) for v in value.values())
+     if isinstance(value, np.ndarray):
+         # The size in bytes, not a copy of the array.
+         return value.nbytes
+     if hasattr(value, "clone"):
+         return value.numel() * size_type(value.dtype)
+     if value.__class__.__name__ in {
+         "DynamicCache",
+         "SlidingWindowCache",
+         "HybridCache",
+         "StaticCache",
+     }:
+         cc = CacheKeyValue(value)
+         return torch_tensor_size(cc.key_cache) + torch_tensor_size(cc.value_cache)
+     if value.__class__.__name__ == "EncoderDecoderCache":
+         return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
+             value.cross_attention_cache
+         )
+     if value.__class__.__name__ == "MambaCache":
+         return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
+     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+         args, _spec = torch.utils._pytree.tree_flatten(value)
+         return sum(torch_tensor_size(a) for a in args)
+
+     # We should have a code using serialization, deserialization assuming a model
+     # cannot be exported without them.
+     raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
+
+
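The size is summed over every tensor found in the structure; a sketch:

    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_tensor_size

    data = [torch.zeros(4, dtype=torch.float32), {"b": torch.zeros(2, dtype=torch.int64)}]
    print(torch_tensor_size(data))  # 4 * 4 + 2 * 8 = 32 bytes
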
+ def model_statistics(model: torch.nn.Module):
+     """Returns statistics on a model in a dictionary."""
+     n_subs = len(list(model.modules()))
+     sizes = {}
+     param_size = 0
+     for param in model.parameters():
+         size = param.nelement() * param.element_size()
+         param_size += size
+         name = str(param.dtype).replace("torch.", "")
+         if name not in sizes:
+             sizes[name] = 0
+         sizes[name] += size
+
+     buffer_size = 0
+     for buffer in model.buffers():
+         size = buffer.nelement() * buffer.element_size()
+         buffer_size += size
+         name = str(buffer.dtype).replace("torch.", "")
+         if name not in sizes:
+             sizes[name] = 0
+         sizes[name] += size
+
+     res = dict(
+         type=model.__class__.__name__,
+         n_modules=n_subs,
+         param_size=param_size,
+         buffer_size=buffer_size,
+         size_mb=(param_size + buffer_size) // 2**20,
+     )
+     res.update(sizes)
+     return res
+
+
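A sketch on a tiny module (byte counts assume float32 parameters):

    import torch
    from onnx_diagnostic.helpers.torch_helper import model_statistics

    print(model_statistics(torch.nn.Linear(4, 2)))
    # {'type': 'Linear', 'n_modules': 1, 'param_size': 40,
    #  'buffer_size': 0, 'size_mb': 0, 'float32': 40}
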
+ def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
+     """
+     Converts a TensorProto into a torch tensor.
+
+     :param tensor: a TensorProto object
+     :param base_dir: if an external tensor exists, base_dir can help to find the path to it
+     :return: the converted tensor
+     """
+     assert not tensor.HasField("segment"), "Currently not supporting loading segments."
+     assert (
+         tensor.data_type != onnx.TensorProto.UNDEFINED
+     ), "The element type in the input tensor is not defined."
+     assert tensor.data_type != onnx.TensorProto.STRING, "to_tensor not implemented for strings"
+
+     tensor_dtype = tensor.data_type
+     torch_dtype = onnx_dtype_to_torch_dtype(tensor_dtype)
+     dims = tuple(tensor.dims)
+     if uses_external_data(tensor):
+         # Load raw data from an external tensor if it exists.
+         load_external_data_for_tensor(tensor, base_dir)
+
+     if tensor.HasField("raw_data"):
+         raw_data = tensor.raw_data
+         if len(raw_data) == 0:
+             return torch.tensor([], dtype=torch_dtype).reshape(dims)
+         if sys.byteorder == "big":
+             # Convert endianness from little to big.
+             raw_data = torch.frombuffer(raw_data, dtype=torch_dtype).byteswap().tobytes()
+         with warnings.catch_warnings():
+             warnings.simplefilter("ignore")
+             return torch.frombuffer(raw_data, dtype=torch_dtype).reshape(dims)
+
+     # In other cases, it should be a small tensor. We use numpy.
+     np_tensor = to_array_extended(tensor)
+     return torch.from_numpy(np_tensor)
+
+
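Together with proto_from_tensor above, this gives a round trip (sketch):

    import torch
    from onnx_diagnostic.helpers.torch_helper import proto_from_tensor, to_tensor

    t = torch.rand(2, 3, dtype=torch.float32)
    back = to_tensor(proto_from_tensor(t, name="t"))
    assert torch.equal(t, back)
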
+ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
+     """Returns the most probable dtype in a model."""
+     counts = {}
+     for _name, param in model.named_parameters():
+         dt = param.dtype
+         if dt not in counts:
+             counts[dt] = 1
+         else:
+             counts[dt] += 1
+     # Pick the dtype with the highest count, not the highest dtype.
+     final = max(counts.items(), key=lambda kv: kv[1])
+     return final[0]
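A closing sketch: the dominant parameter dtype wins.

    import torch
    from onnx_diagnostic.helpers.torch_helper import get_weight_type

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)).to(torch.float16)
    print(get_weight_type(model))  # torch.float16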