onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/mini_onnx_builder.py
@@ -0,0 +1,600 @@
+ import ctypes
+ import sys
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+ import numpy as np
+ from onnx import GraphProto, ModelProto, NodeProto, TensorProto
+ import onnx.helper as oh
+ import torch
+ from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype, from_array_extended
+ from . import string_type
+
+ STORAGE_TYPE = {
+     TensorProto.FLOAT16: np.int16,
+     TensorProto.BFLOAT16: np.int16,
+ }
+
+
+ def proto_from_array(
+     arr: torch.Tensor,
+     name: Optional[str] = None,
+     verbose: int = 0,
+ ) -> TensorProto:
+     """
+     Converts a torch Tensor into a TensorProto.
+
+     :param arr: tensor
+     :param name: name of the initializer
+     :param verbose: display the type and shape
+     :return: a TensorProto
+     """
+     if not isinstance(arr, torch.Tensor):
+         raise TypeError(f"Unexpected type {type(arr)}.")
+     if arr.is_sparse:
+         raise NotImplementedError(
+             f"Sparse tensor is not supported yet but initializer {name!r} is."
+         )
+
+     # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
+     arr_cpu = arr.cpu() if arr.is_contiguous() else arr.contiguous().cpu()
+
+     numel = torch.numel(arr_cpu)
+     element_size = arr_cpu.element_size()
+
+     if arr_cpu.dtype in {torch.bfloat16}:
+         np_arr = arr_cpu
+     elif arr_cpu.data_ptr() == arr.data_ptr():
+         copy = arr_cpu.clone().detach().requires_grad_(False)
+         assert (
+             arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
+         ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
+         np_arr = np.from_dlpack(copy)
+     else:
+         np_arr = np.from_dlpack(arr_cpu.detach())
+
+     tensor = TensorProto()
+     tensor.dims.extend(arr_cpu.shape)
+     tensor.name = name
+     itype = dtype_to_tensor_dtype(arr_cpu.dtype)
+     assert not hasattr(TensorProto, "INT4") or itype not in {
+         TensorProto.INT4,
+         TensorProto.UINT4,
+     }, f"Type {arr.dtype} is not supported yet for name={name!r}"
+     tensor.data_type = itype
+
+     if verbose > 1 and numel > 100:
+         print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")
+
+     if isinstance(np_arr, torch.Tensor):
+         byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
+         tensor.raw_data = bytes(byte_data)
+         if sys.byteorder == "big":
+             np_dtype = tensor_dtype_to_np_dtype(STORAGE_TYPE[tensor.data_type])
+             np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+     else:
+         tensor.raw_data = np_arr.tobytes()
+         if sys.byteorder == "big":
+             np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+             np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
+
+     return tensor
+
+
+ class MiniOnnxBuilder:
+     """
+     Simplified builder to build very simple models.
+
+     :param target_opset: opset to specify
+     :param ir_version: IR version to use
+     :param sep: separator to build output names
+     """
+
+     def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___"):
+         self.initializers_dict: Dict[str, Any] = {}
+         self.inputs: List[Any] = []
+         self.outputs: List[Any] = []
+         self.nodes: List[NodeProto] = []
+         self.opsets = {"": target_opset}
+         self.ir_version = ir_version
+         self.sep = sep
+
+     def append_output_initializer(
+         self,
+         name: str,
+         tensor: Union[np.ndarray, torch.Tensor],
+         randomize: bool = False,
+     ):
+         """
+         Adds an initializer as an output.
+         The initializer name is prefixed by ``t_``.
+         The output name is *name*.
+         If ``randomize`` is True, the tensor is not stored but replaced by a random generator.
+         """
+         if randomize:
+             dtype = dtype_to_tensor_dtype(tensor.dtype)
+             if dtype in {
+                 TensorProto.FLOAT,
+                 TensorProto.FLOAT16,
+                 TensorProto.DOUBLE,
+                 TensorProto.BFLOAT16,
+             }:
+                 mini, maxi = tensor.min(), tensor.max()
+                 if mini < 0 and maxi > 0:
+                     op_type = "RandomNormal"
+                     kwargs = {
+                         "mean": float(tensor.mean()),
+                         "scale": float(tensor.std()),
+                         "seed": 0.0,
+                     }
+                 else:
+                     op_type = "RandomUniform"
+                     kwargs = {
+                         "low": float(mini),
+                         "high": float(maxi),
+                         "seed": 0.0,
+                     }
+                 shape = tuple(map(int, tensor.shape))
+                 self.nodes.append(
+                     oh.make_node(op_type, [], [name], dtype=dtype, shape=shape, **kwargs)
+                 )
+                 self.outputs.append(oh.make_tensor_value_info(name, dtype, shape))
+                 return
+
+         init_name = f"t_{name}"
+         assert (
+             init_name not in self.initializers_dict
+         ), f"name={init_name!r} already in {sorted(self.initializers_dict)}"
+         self.initializers_dict[init_name] = tensor
+         shape = tuple(map(int, tensor.shape))
+         self.outputs.append(
+             oh.make_tensor_value_info(name, dtype_to_tensor_dtype(tensor.dtype), shape)
+         )
+         self.nodes.append(oh.make_node("Identity", [init_name], [name]))
+
+     def append_output_sequence(
+         self, name: str, tensors: List[Union[np.ndarray, torch.Tensor]]
+     ):
+         """
+         Adds a sequence of initializers as an output.
+         The initializer names are prefixed by ``seq_``.
+         The output name is ``name``.
+         """
+         if not tensors:
+             # empty list
+             self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
+             tensor_type_proto = oh.make_tensor_type_proto(
+                 elem_type=TensorProto.FLOAT, shape=None
+             )
+         else:
+             assert all(
+                 isinstance(t, (np.ndarray, torch.Tensor)) for t in tensors
+             ), f"Nested sequences are not supported, types are {[type(t) for t in tensors]}"
+             names = []
+             for i, t in enumerate(tensors):
+                 init_name = f"seq_{name}_{i}"
+                 self.initializers_dict[init_name] = t
+                 names.append(init_name)
+
+             self.nodes.append(oh.make_node("SequenceConstruct", names, [name]))
+             tensor_type_proto = oh.make_tensor_type_proto(
+                 elem_type=dtype_to_tensor_dtype(tensors[0].dtype), shape=None
+             )
+
+         sequence_type_proto = oh.make_sequence_type_proto(tensor_type_proto)
+         output = oh.make_value_info(name, type_proto=sequence_type_proto)
+         self.outputs.append(output)
+
+     def append_output_dict(
+         self, name: str, tensors: Dict[str, Union[np.ndarray, torch.Tensor]]
+     ):
+         """
+         Adds two outputs, a string tensor for the keys and a sequence of tensors
+         for the values.
+
+         The output names are ``name___keys`` and ``name___values``.
+         """
+         keys = []
+         values = []
+         for k, v in tensors.items():
+             keys.append(k)
+             values.append(v)
+         self.append_output_initializer(f"{name}{self.sep}keys", np.array(keys, dtype=np.str_))
+         self.append_output_sequence(f"{name}{self.sep}values", values)
+
+     def _build_initializers(self, switch_low_high: bool) -> List[TensorProto]:
+         """
+         Builds initializers.
+
+         :param switch_low_high: invert low, high precision
+         :return: a list of tensors to be stored in the model
+         """
+         init_dict = self.initializers_dict
+         if switch_low_high:
+             # Let's try to minimize the time.
+             initializer: List[TensorProto] = []
+             for k, v in init_dict.items():
+                 if isinstance(v, TensorProto):
+                     initializer.append(v)
+                     continue
+
+                 if isinstance(v, np.ndarray):
+                     itype = dtype_to_tensor_dtype(v.dtype)
+                     if itype in {
+                         TensorProto.BOOL,
+                         TensorProto.STRING,
+                         TensorProto.UNDEFINED,
+                         TensorProto.COMPLEX64,
+                         TensorProto.COMPLEX128,
+                         getattr(TensorProto, "UINT4", 0),
+                         getattr(TensorProto, "INT4", 0),
+                     }:
+                         t = from_array_extended(v, name=k)
+                         initializer.append(t)
+                         continue
+
+                     from_np = True
+                 elif isinstance(v, np.float32):
+                     t = from_array_extended(np.array([v], dtype=np.float32), name=k)
+                     initializer.append(t)
+                     continue
+                 elif isinstance(v, np.float64):
+                     t = from_array_extended(np.array([v], dtype=np.float64), name=k)
+                     initializer.append(t)
+                     continue
+                 elif isinstance(v, np.float16):
+                     t = from_array_extended(np.array([v], dtype=np.float16), name=k)
+                     initializer.append(t)
+                     continue
+                 else:
+                     assert isinstance(
+                         v, torch.Tensor
+                     ), f"tensor {k!r} has an unexpected type {type(v)}"
+                     assert "FakeTensor" not in str(
+                         type(v)
+                     ), f"tensor {k!r} cannot be a FakeTensor: {type(v)}"
+                     from_np = False
+                     itype = dtype_to_tensor_dtype(v.dtype)
+
+                 # How to avoid a copy?
+                 if from_np:
+                     tensor = TensorProto()
+                     tensor.name = k
+                     tensor.dims.extend(v.shape)
+                     tensor.data_type = itype
+                     tensor.raw_data = v.tobytes()
+                 else:
+                     tensor = proto_from_array(v, name=k)
+
+                 initializer.append(tensor)
+
+             return initializer
+
+         res: List[TensorProto] = []
+         for k, v in init_dict.items():
+             if isinstance(v, TensorProto):
+                 res.append(v)
+                 continue
+             if isinstance(v, torch.Tensor):
+                 # no string tensor
+                 t = proto_from_array(v, name=k)
+                 res.append(t)
+                 continue
+             if isinstance(v, np.ndarray):
+                 t = from_array_extended(v, name=k)
+                 res.append(t)
+                 continue
+             raise TypeError(
+                 f"Unable to convert initializer {k!r} with type "
+                 f"{type(v)} into a TensorProto."
+             )
+         return res
+
+     def to_onnx(self) -> ModelProto:
+         """
+         Conversion to onnx.
+
+         :return: the proto
+         """
+         opsets = [oh.make_opsetid(*o) for o in self.opsets.items()]
+         ir_version = self.ir_version
+         model = ModelProto()
+         model.graph.CopyFrom(GraphProto())
+         model.graph.name = "mini_model"
+         model.graph.input.extend(self.inputs)
+         model.graph.node.extend(self.nodes)
+         model.graph.output.extend(self.outputs)
+         initializers = self._build_initializers(switch_low_high=sys.byteorder != "big")
+         model.graph.initializer.extend(initializers)
+         model.opset_import.extend(opsets)
+         model.ir_version = ir_version
+         return model
+
+
+ def _flatten_iterator(obj: Any, sep: str) -> Iterator:
+     """Iterates over all nested objects."""
+     if obj is not None:
+         if isinstance(obj, np.ndarray):
+             yield "array", obj
+         elif isinstance(obj, torch.Tensor):
+             yield "tensor", obj
+         elif isinstance(obj, bool):
+             yield "bool", np.array([obj], dtype=np.bool_)
+         elif isinstance(obj, int):
+             yield "int", np.array([obj], dtype=np.int64)
+         elif isinstance(obj, float):
+             yield "float", np.array([obj], dtype=np.float64)
+         elif isinstance(obj, tuple):
+             if not obj:
+                 yield f"tuple.{sep}empty", None
+             else:
+                 for i, o in enumerate(obj):
+                     if i == len(obj) - 1:
+                         for p, oo in _flatten_iterator(o, sep):
+                             yield f"tuple_{i}.{sep}{p}", oo
+                     else:
+                         for p, oo in _flatten_iterator(o, sep):
+                             yield f"tuple_{i}{sep}{p}", oo
+         elif isinstance(obj, list):
+             if not obj:
+                 yield f"list.{sep}empty", None
+             else:
+                 for i, o in enumerate(obj):
+                     if i == len(obj) - 1:
+                         for p, oo in _flatten_iterator(o, sep):
+                             yield f"list_{i}.{sep}{p}", oo
+                     else:
+                         for p, oo in _flatten_iterator(o, sep):
+                             yield f"list_{i}{sep}{p}", oo
+         elif isinstance(obj, dict):
+             if not obj:
+                 yield f"dict.{sep}empty", None
+             else:
+                 for i, (k, v) in enumerate(obj.items()):
+                     assert sep not in k, (
+                         f"Key {k!r} cannot contain '{sep}'. "
+                         f"It would interfere with the serialization."
+                     )
+
+                     def _mk(k):
+                         if isinstance(k, tuple):
+                             # this assumes the tuple contains simple types
+                             return f"(({','.join(map(str, k))}))"
+                         return str(k)
+
+                     if i == len(obj) - 1:
+                         for p, o in _flatten_iterator(v, sep):
+                             yield f"dict._{_mk(k)}{sep}{p}", o
+                     else:
+                         for p, o in _flatten_iterator(v, sep):
+                             yield f"dict_{_mk(k)}{sep}{p}", o
+         elif obj.__class__.__name__ == "DynamicCache":
+             # transformers
+             import transformers
+             from .cache_helper import CacheKeyValue
+
+             assert isinstance(
+                 obj, transformers.cache_utils.DynamicCache
+             ), f"Unexpected type {type(obj)}"
+             obj = CacheKeyValue(obj)
+             atts = ["key_cache", "value_cache"]
+             for i, att in enumerate(atts):
+                 if i == len(atts) - 1:
+                     for p, o in _flatten_iterator(getattr(obj, att), sep):
+                         yield f"DynamicCache._{att}{sep}{p}", o
+                 else:
+                     for p, o in _flatten_iterator(getattr(obj, att), sep):
+                         yield f"DynamicCache_{att}{sep}{p}", o
+         elif obj.__class__.__name__ == "StaticCache":
+             # transformers
+             import transformers
+             from .cache_helper import CacheKeyValue
+
+             assert isinstance(
+                 obj, transformers.cache_utils.StaticCache
+             ), f"Unexpected type {type(obj)}"
+             obj = CacheKeyValue(obj)
+             atts = ["key_cache", "value_cache"]
+             for i, att in enumerate(atts):
+                 if i == len(atts) - 1:
+                     for p, o in _flatten_iterator(getattr(obj, att), sep):
+                         yield f"StaticCache._{att}{sep}{p}", o
+                 else:
+                     for p, o in _flatten_iterator(getattr(obj, att), sep):
+                         yield f"StaticCache_{att}{sep}{p}", o
+         else:
+             raise NotImplementedError(f"Unexpected type {type(obj)}")
+
+
+ def create_onnx_model_from_input_tensors(
+     inputs: Any,
+     switch_low_high: Optional[bool] = None,
+     randomize: bool = False,
+     sep: str = "___",
+ ) -> ModelProto:
+     """
+     Creates a model proto that stores all the values as initializers.
+     They can be restored by executing the model.
+     We assume these inputs are not bigger than 2 GB,
+     the limit of protobuf. Nothing is implemented yet to get around
+     that limit.
+
+     :param inputs: anything
+     :param switch_low_high: if None, it defaults to ``sys.byteorder != "big"``
+     :param randomize: if True, float tensors are not stored but randomized to save space
+     :param sep: separator
+     :return: ModelProto
+
+     The function raises an error when the input type is not supported.
+     """
+     if switch_low_high is None:
+         switch_low_high = sys.byteorder != "big"
+
+     builder = MiniOnnxBuilder(sep=sep)
+     for prefix, o in _flatten_iterator(inputs, sep):
+         if o is None:
+             builder.append_output_initializer(prefix, np.array([]))
+         else:
+             builder.append_output_initializer(prefix, o, randomize=randomize)
+     model = builder.to_onnx()
+     model.doc_string = string_type(inputs, True, True)
+     return model
+
+
+ def _unflatten(
+     sep: str,
+     names: List[str],
+     outputs: List[Any],
+     pos: int = 0,
+     level: int = 0,
+     device: str = "cpu",
+ ) -> Tuple[int, Any]:
+     """Unflattens a list of outputs flattened with :func:`_flatten_iterator`."""
+     name = names[pos]
+     spl = name.split(sep)
+     if len(spl) == level + 1:
+         # A tensor.
+         if spl[-1] == "empty":
+             return pos + 1, None
+         if spl[-1] == "bool":
+             return pos + 1, bool(outputs[pos][0])
+         if spl[-1] == "int":
+             return pos + 1, int(outputs[pos][0])
+         if spl[-1] == "float":
+             return pos + 1, float(outputs[pos][0])
+         if spl[-1] == "array":
+             return pos + 1, outputs[pos]
+         if spl[-1] == "tensor":
+             return pos + 1, torch.from_numpy(outputs[pos]).to(device)
+         raise AssertionError(f"Unexpected name {name!r} in {names}")
+
+     res: List[Any] = []
+     while True:
+         assert pos < len(names), f"Something went wrong with names={names!r}\nres={res!r}"
+         name = names[pos]
+         spl = name.split(sep)
+         prefix = spl[level]
+         next_pos, value = _unflatten(
+             sep, names, outputs, pos=pos, level=level + 1, device=device
+         )
+
+         if prefix.startswith("DynamicCache"):
+             key = prefix.split("_", maxsplit=1)[-1]
+             res.append((key, value))
+             lp = len("DynamicCache")
+             end = len(prefix) > lp and prefix[lp] == "."
+         elif prefix.startswith("dict"):
+             key = prefix.split("_", maxsplit=1)[-1]
+             res.append((key, value))
+             end = len(prefix) > 4 and prefix[4] == "."
+         else:
+             res.append(value)
+             end = prefix[-1] == "."
+
+         if end:
+             if prefix.startswith("dict"):
+                 ty: type = dict
+             elif prefix.startswith("list"):
+                 ty = list
+             elif prefix.startswith("tuple"):
+                 ty = tuple
+             elif prefix.startswith("DynamicCache"):
+                 from transformers.cache_utils import DynamicCache
+
+                 ty = DynamicCache
+             else:
+                 raise AssertionError(f"Unexpected prefix={prefix!r}")
+             break
+         pos = next_pos
+
+     def _tryint(s):
+         try:
+             return int(s)
+         except (ValueError, TypeError):
+             if s in {"True", "False"}:
+                 return s == "True"
+             return s
+
+     def _make(ty: type, res: Any) -> Any:
+         if ty.__name__ == "DynamicCache":
+             from .cache_helper import CacheKeyValue
+
+             cc = CacheKeyValue()
+             for k, v in res:
+                 setattr(cc, k, v)
+             r = cc.make_dynamic_cache()
+             return r
+         if ty is dict:
+             d = {}
+             for k, v in res:
+                 if k.startswith("((") and k.endswith("))"):
+                     spl = k[2:-2].split(",")
+                     key = tuple(_tryint(s) for s in spl)
+                 else:
+                     key = _tryint(k)
+                 d[key] = v
+             return d
+         return ty(res)
+
+     return next_pos, (
+         ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
+     )
+
+
+ def create_input_tensors_from_onnx_model(
+     proto: Union[str, ModelProto],
+     device: str = "cpu",
+     engine: str = "ExtendedReferenceEvaluator",
+     sep: str = "___",
+ ) -> Any:
+     """
+     Deserializes tensors stored with function
+     :func:`create_onnx_model_from_input_tensors`.
+     It relies on :class:`ExtendedReferenceEvaluator
+     <onnx_diagnostic.reference.ExtendedReferenceEvaluator>`
+     to restore the tensors.
+
+     :param proto: ModelProto or a path to the model file
+     :param device: moves the tensors to this device
+     :param engine: runtime to use: ``ExtendedReferenceEvaluator`` (default),
+         ``onnx`` or ``onnxruntime``
+     :param sep: separator
+     :return: restored data
+
+     See :ref:`l-plot-intermediate-results` for an example.
+     """
+     if engine == "ExtendedReferenceEvaluator":
+         from ..reference import ExtendedReferenceEvaluator
+
+         sess = ExtendedReferenceEvaluator(proto)
+         names = sess.output_names
+     elif engine == "onnx":
+         from onnx.reference import ReferenceEvaluator
+
+         sess = ReferenceEvaluator(proto)
+         names = sess.output_names
+     elif engine == "onnxruntime":
+         from onnxruntime import InferenceSession
+
+         sess = InferenceSession(
+             proto if isinstance(proto, str) else proto.SerializeToString(),
+             providers=["CPUExecutionProvider"],
+         )
+         names = [i.name for i in sess.get_outputs()]
+     else:
+         raise AssertionError(f"Unexpected value for engine={engine!r}")
+
+     got = sess.run(None, {})
+     if len(names) == 1:
+         name = names[0]
+         output = got[0]
+         if name == "empty":
+             return None
+         if name == "array":
+             return output
+         if name == "bool":
+             return bool(output[0])
+         if name == "int":
+             return int(output[0])
+         if name == "float":
+             return float(output[0])
+         if name == "tensor":
+             return torch.from_numpy(output).to(device)
+         raise AssertionError(f"Unexpected name {name!r} in {names}")
+
+     return _unflatten(sep, names, got, device=device)[1]
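
The two top-level helpers in this module form a round trip: create_onnx_model_from_input_tensors flattens an arbitrary nested structure into a model whose outputs hold the stored values, and create_input_tensors_from_onnx_model runs that model to rebuild the structure. Below is a minimal sketch of that round trip, assuming onnx-diagnostic 0.8.0 is installed together with numpy and torch; the nested "inputs" structure is illustrative and not taken from the package.

import numpy as np
import torch
from onnx_diagnostic.helpers.mini_onnx_builder import (
    create_input_tensors_from_onnx_model,
    create_onnx_model_from_input_tensors,
)

# An illustrative nested structure; any mix of numpy arrays, torch tensors,
# ints, floats, bools, tuples, lists and dicts should serialize the same way.
inputs = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "masks": [np.zeros((2, 8), dtype=np.float32), torch.rand(2, 8)],
    "flag": True,
}

# Every flattened value becomes an initializer exposed as a model output.
proto = create_onnx_model_from_input_tensors(inputs)

# Running the model (here with the default ExtendedReferenceEvaluator)
# rebuilds the structure: numpy arrays stay numpy, torch tensors come back
# as torch tensors on the requested device.
restored = create_input_tensors_from_onnx_model(proto, device="cpu")
assert list(restored) == list(inputs)

Passing engine="onnxruntime" restores the values with onnxruntime instead of the reference evaluator, and randomize=True in the serializing call swaps large float tensors for RandomNormal/RandomUniform nodes when only the shapes and rough statistics matter.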