onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
File without changes
@@ -0,0 +1,289 @@
1
+ import enum
2
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
3
+ import onnx
4
+ import torch
5
+ from ..api import TensorLike
6
+ from ..helpers import string_type
7
+
8
+
9
class RuntimeValueKind(enum.IntEnum):
    """Kind of result."""

    RESULT = 1
    INITIALIZER = 3
    INPUT = 5
    OUTPUT = 9

    def to_str(self) -> str:
        """Returns the name of the member, ``"INPUT"`` for ``RuntimeValueKind.INPUT``."""
        # ``name`` is the documented accessor for a member's name. Scanning
        # the class ``__dict__`` and comparing values with ``==`` depends on
        # how ``enum`` lays out its internals and may match an entry which
        # is not a member (``True == 1`` for example).
        return self.name
22
+
23
+
24
class RuntimeDevice(enum.IntEnum):
    """Device definition"""

    UNKNOWN = 0
    NEW = 1
    CPU = 2
    CUDA = 4

    def to_str(self) -> str:
        """Returns the name of the member, ``"CPU"`` for ``RuntimeDevice.CPU``."""
        # ``name`` is the documented accessor for a member's name. Scanning
        # the class ``__dict__`` and comparing values with ``==`` depends on
        # how ``enum`` lays out its internals and may match an entry which
        # is not a member (``False == 0`` or ``True == 1`` for example).
        return self.name
37
+
38
+
39
class RuntimeValue:
    """
    Describes a value used during the execution of a model.

    :param name: name of the result
    :param dtype: element type, taken from the value by :meth:`set_value`
        when it is not already known
    :param shape: shape, overwritten by :meth:`set_value`
    :param value: the value itself, None if not computed or already released
    :param first_used: index of the first node using this result
    :param last_used: index of the last node using this result
    :param created: index of the node creating this result
    :param is_shape: tells if the result is a shape, None if unknown
    :param kind: kind of result, see :class:`RuntimeValueKind`
    :param device: device, see :class:`RuntimeDevice`
    """

    def __init__(
        self,
        name: str,
        dtype: Optional[Any] = None,
        shape: Optional[Tuple[Union[str, int], ...]] = None,
        value: Optional[Any] = None,
        first_used: Optional[int] = None,
        last_used: Optional[int] = None,
        created: Optional[int] = None,
        is_shape: Optional[bool] = None,
        kind: Optional[RuntimeValueKind] = None,
        device: Optional[RuntimeDevice] = None,
    ):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.value = value
        self.first_used = first_used
        self.last_used = last_used
        self.created = created
        self.is_shape = is_shape
        self.kind = kind
        self.device = device

    def _value_as_string(self) -> str:
        "Describes the stored value, the value must not be None."
        # A value exposing its own ``string_type`` describes itself,
        # anything else (a plain torch.Tensor for example) goes through
        # the generic helper.
        if hasattr(self.value, "string_type"):
            return self.value.string_type()
        return string_type(self.value, with_shape=True)

    def __repr__(self) -> str:
        "Displays every attribute which is not None."
        ad = {}
        for att in (
            "name",
            "dtype",
            "shape",
            "first_used",
            "last_used",
            "is_shape",
            "kind",
            "created",
            "device",
        ):
            v = getattr(self, att)
            if v is not None:
                ad[att] = v
        if self.value is not None:
            ad["value"] = self._value_as_string()
        # Enumerations are displayed with their name rather than their integer value.
        msg = ", ".join(
            f"{name}={t.to_str()}" if hasattr(t, "to_str") else f"{name}={t}"
            for name, t in ad.items()
        )
        return f"{self.__class__.__name__}({msg})"

    @property
    def has_value(self) -> bool:
        "Tells if value is specified."
        return self.value is not None

    def string_type(self) -> str:
        "Returns a string describing the value."
        rows = []
        if self.shape is not None:
            rows.append(f"shape={self.shape}")
        if self.is_shape is not None:
            rows.append(f"is_shape={self.is_shape}")
        if self.device is not None:
            rows.append(f"device={self.device}")
        text = f", {', '.join(rows)}" if rows else ""
        if self.value is None:
            return (
                f"RuntimeValue(name={self.name!r}{text}"
                f", dtype={self.dtype}, kind={self.kind})"
            )
        # Same fallback as __repr__: the stored value is not guaranteed
        # to implement ``string_type``.
        return (
            f"RuntimeValue(name={self.name!r}, "
            f"kind={self.kind}{text}, value={self._value_as_string()})"
        )

    def set_value(self, value: Union[torch.Tensor, TensorLike]):
        """
        Sets the value. The type is checked against the known one if any,
        otherwise it is taken from the value. The shape is always
        refreshed from the value, it is None for a sequence.
        """
        assert value is not None, "Use clean_value to set a value to None"
        self.value = value
        is_sequence = hasattr(value, "is_sequence") and value.is_sequence()
        if self.dtype:
            assert self.dtype == value.dtype, (
                f"Unexpected dtype={value.dtype}, previous dtype was {self.dtype}, "
                f"is_sequence={is_sequence}"
            )
        else:
            self.dtype = value.dtype
        self.shape = None if is_sequence else tuple(map(int, value.shape))

    def clean_value(self):
        """Sets value to None."""
        self.value = None

    @property
    def is_output(self) -> bool:
        "Tells if it is an output."
        return self.kind == RuntimeValueKind.OUTPUT

    @property
    def is_input(self) -> bool:
        "Tells if it is an input."
        return self.kind == RuntimeValueKind.INPUT

    @property
    def is_initializer(self) -> bool:
        "Tells if it is an initializer."
        return self.kind == RuntimeValueKind.INITIALIZER
152
+
153
+
154
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """
    Collects the names a graph consumes without defining them,
    these results come from the context the graph is embedded in.
    Subgraphs stored in node attributes are explored recursively.

    :param graph: graph (usually a subgraph) to inspect
    :return: set of names coming from an upper context
    """
    # Every name the graph defines by itself so far.
    known = {init.name for init in graph.initializer}
    known.update(init.name for init in graph.sparse_initializer)
    known.update(inp.name for inp in graph.input)

    missing = set()
    for node in graph.node:
        missing.update(name for name in node.input if name not in known)
        for att in node.attribute:
            if att.type != onnx.AttributeProto.GRAPH or not att.g:
                continue
            # A hidden input of a nested graph is hidden for this graph
            # as well unless this graph already defines it.
            missing.update(h for h in get_hidden_inputs(att.g) if h not in known)
        # Outputs only become known after the node was processed.
        known.update(node.output)
    return missing
176
+
177
+
178
def set_is_shape(
    node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
) -> List[str]:
    """
    Updates attribute ``is_shape`` of the outputs of a node.

    :param node: node to process
    :param values: stored results, values in this dictionary are updated
    :param drop: names to ignore because they come from the graph
        holding this subgraph
    :return: list of modified results
    """
    if not node.input:
        # A node without any input (Constant), nothing to propagate.
        return []
    ignored = drop or set()

    if node.domain == "" and node.op_type in ("Shape", "Size"):
        # These operators always produce a shape.
        produced = node.output[0]
        values[produced].is_shape = True
        return [produced]

    flags = [values[name].is_shape for name in node.input if name not in ignored]
    if not any(flags):
        # No input is known to be a shape, outputs are left untouched.
        return []

    if flags[0] and len(node.output) == 1:
        # The first considered input is a shape and the node has a single
        # output: the output is tagged as a shape as well.
        produced = node.output[0]
        values[produced].is_shape = True
        return [produced]

    # A shape is involved but not in first position or the node has
    # several outputs: the outputs are explicitly tagged as not a shape.
    for name in node.output:
        values[name].is_shape = False
    return list(node.output)
207
+
208
+
209
def first_used_last_used(
    proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
    constant_as_initializer: bool = False,
) -> Dict[str, RuntimeValue]:
    """
    Builds first used, last used information for every result
    in the model.

    Initializers and inputs are registered with ``created=-1``, the result
    of a node is registered with the index of this node. ``first_used`` and
    ``last_used`` are indices of nodes consuming the result, either directly
    or through a subgraph stored in an attribute. Outputs get
    ``last_used`` equal to the number of nodes.

    :param proto: model, graph or function
    :param constant_as_initializer: outputs of node Constant is tagged as INITIALIZER
    :return: dictionary of RuntimeValue
    """
    values = {}
    if isinstance(proto, onnx.ModelProto):
        initializer = proto.graph.initializer
        sparse_initializer = proto.graph.sparse_initializer
        _input = proto.graph.input
        output = proto.graph.output
        _node = proto.graph.node
        allow_unknown = False
    elif isinstance(proto, onnx.GraphProto):
        initializer = proto.initializer
        sparse_initializer = proto.sparse_initializer
        _input = proto.input
        output = proto.output
        _node = proto.node
        # A graph may be a subgraph using results defined in an upper context.
        allow_unknown = True
    else:
        # FunctionProto: no initializer, inputs and outputs are plain strings.
        initializer = []
        sparse_initializer = []
        _input = proto.input
        output = proto.output
        _node = proto.node
        allow_unknown = False

    for init in initializer:
        values[init.name] = RuntimeValue(
            init.name, kind=RuntimeValueKind.INITIALIZER, created=-1
        )
    for init in sparse_initializer:
        values[init.name] = RuntimeValue(
            init.name, created=-1, kind=RuntimeValueKind.INITIALIZER
        )
    for inp in _input:
        n = inp if isinstance(inp, str) else inp.name
        values[n] = RuntimeValue(n, created=-1, kind=RuntimeValueKind.INPUT)

    # Names ignored by set_is_shape: results coming from an upper context
    # and the empty name standing for an omitted optional input.
    drop = set()
    for it, node in enumerate(_node):
        for i in node.input:
            if not i:
                # An empty name means an optional input was not specified,
                # it is not a result and must not be reported as unknown.
                drop.add(i)
                continue
            if i not in values:
                assert allow_unknown, f"Input {i!r} is unknown."
                # This input comes from a context and the model is a GraphProto
                drop.add(i)
                continue
            if values[i].first_used is None:
                values[i].first_used = it
            values[i].last_used = it
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                # Results consumed inside a subgraph are used by this node.
                for n in get_hidden_inputs(att.g):
                    if not n:
                        # omitted optional input inside the subgraph
                        continue
                    if n not in values:
                        # Same rule as for a direct input: only a GraphProto
                        # may refer to results defined in an upper context.
                        assert allow_unknown, f"Input {n!r} is unknown."
                        drop.add(n)
                        continue
                    if values[n].first_used is None:
                        values[n].first_used = it
                    values[n].last_used = it
        is_constant = node.op_type == "Constant" and node.domain == ""
        for o in node.output:
            values[o] = RuntimeValue(
                o,
                created=it,
                kind=(
                    RuntimeValueKind.INITIALIZER
                    if is_constant and constant_as_initializer
                    else RuntimeValueKind.RESULT
                ),
            )
        set_is_shape(node, values, drop=drop)

    for out in output:
        n = out if isinstance(out, str) else out.name
        values[n].kind = RuntimeValueKind.OUTPUT
        # Outputs must stay alive until the end of the execution.
        values[n].last_used = len(_node)
    return values