onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,440 @@
1
+ from typing import Any, Dict, Iterator, Optional, Tuple, Union
2
+ import onnx
3
+ import torch
4
+ from ..helpers import string_type, string_diff, max_diff
5
+ from ..helpers.onnx_helper import to_array_extended
6
+ from ..helpers.torch_helper import to_numpy
7
+
8
+
9
def validate_fx_tensor(
    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
) -> None:
    """
    Checks that ``tensor`` has the expected shape.

    Ranks must match exactly. Every integer dimension must match as well,
    except that 0 and 1 are considered interchangeable. Symbolic
    (non-integer) dimensions are not checked.

    :param node: fx node the tensor belongs to, only used to build error messages
    :param tensor: tensor whose shape is verified
    :param expected_shape: expected shape, may contain symbolic dimensions
    """
    assert len(tensor.shape) == len(expected_shape), (
        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
    for got, exp in zip(tensor.shape, expected_shape):
        if not isinstance(exp, int):
            # Symbolic dimension: nothing to compare against.
            continue
        assert got == exp or {got, exp} == {0, 1}, (
            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
32
+
33
+
34
def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
    """
    Validates the outputs of a fx node against the expected values stored
    under ``node.meta["val"]``. Does nothing when that metadata is missing.

    :param node: node which produced ``outputs``
    :param outputs: produced outputs: a tensor, a sequence of tensors,
        an int or None
    :raises NotImplementedError: for any other output type
    """
    if "val" not in node.meta:
        # No expected value recorded on the node, nothing to check.
        return
    expected = node.meta["val"]
    if isinstance(outputs, torch.Tensor):
        validate_fx_tensor(node, outputs, expected.shape)
        return
    if isinstance(outputs, (tuple, list)):
        assert isinstance(expected, (list, tuple)), (
            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        assert len(outputs) == len(expected), (
            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        for produced, meta_val in zip(outputs, expected):
            validate_fx_tensor(node, produced, meta_val.shape)
        return
    if isinstance(outputs, int):
        # A symbolic expected value cannot be compared to a concrete int.
        assert (
            isinstance(expected, (torch.SymInt, torch.SymBool, torch.SymFloat))
            or outputs == expected
        ), (
            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        return
    if outputs is None:
        assert expected is None, (
            f"None mismatch, got {outputs} expected {node.meta['val']}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        return
    raise NotImplementedError(
        f"Validation for output type {type(outputs)} is not implemented, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
87
+
88
+
89
def run_fx_node(
    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> Tuple[Any, ...]:
    """
    Executes one fx node. Only ``output`` and ``call_function`` nodes
    are supported.

    :param node: node to execute
    :param args: unnamed inputs to the node
    :param kwargs: named inputs to the node
    :return: results produced by the node
    :raises NotImplementedError: for any other kind of node
    """
    operation = node.op
    if operation == "output":
        # An output node simply forwards its single positional argument.
        assert len(args) == 1 and not kwargs, (
            f"Unexpected inputs: args={string_type(args, limit=20)} "
            f"kwargs={string_type(kwargs, limit=20)}"
        )
        return args
    if operation == "call_function":
        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
        results = node.target(*args, **(kwargs or {}))
        validate_fx_outputs(node, results)
        return results
    raise NotImplementedError(
        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
    )
114
+
115
+
116
def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
    "See :func:`prepare_args_kwargs`."
    if isinstance(ref, torch.fx.Node):
        # A node reference is replaced by the value already computed for it.
        return torch_results[ref.name]
    if isinstance(ref, list):
        return [_pick_result(torch_results, item) for item in ref]
    if isinstance(ref, tuple):
        return tuple(_pick_result(torch_results, item) for item in ref)
    if isinstance(ref, dict):
        return {key: _pick_result(torch_results, value) for key, value in ref.items()}
    if ref is None or isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
        # Plain constants are kept as they are.
        return ref
    raise NotImplementedError(f"Unable to process args type {type(ref)}")
131
+
132
+
133
def prepare_args_kwargs(
    torch_results: Dict[str, Any], node: torch.fx.Node
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Builds the concrete args and kwargs needed to execute a fx node,
    replacing every node reference by the value already computed for it.

    :param torch_results: existing results, indexed by name
    :param node: node about to be executed
    :return: resolved ``(args, kwargs)``
    """
    return (
        _pick_result(torch_results, node.args),
        _pick_result(torch_results, node.kwargs),
    )
146
+
147
+
148
def run_aligned(
    ep: torch.export.ExportedProgram,
    onx: Union[onnx.ModelProto, onnx.FunctionProto],
    args: Tuple[torch.Tensor, ...],
    check_conversion_cls: Union[Dict[str, Any], type],
    kwargs: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> Iterator[Tuple[Any, ...]]:
    """
    Runs in parallel both the exported program
    and the onnx proto and looks for discrepancies.
    The function does match on result names so it assumes
    the exported program and the onnx model have the same names
    for equivalent results.

    :param ep: exported program
    :param onx: model or function proto
    :param args: input args
    :param check_conversion_cls: defines the runtime to use for this task,
        either a class running a single NodeProto or a dictionary
        ``dict(cls=..., atol=..., rtol=...)`` to also enforce tolerances
    :param kwargs: input kwargs (not supported yet, must be empty)
    :param verbose: verbosity level
    :return: an iterator of tuples
        ``(ep_node_index, onnx_node_index, onnx_name, torch_name, diff)``,
        one per result shared by both runtimes

    Example:

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import pprint
        import pandas
        import torch
        from onnx_diagnostic.reference import (
            # This can be replaced by any runtime taking NodeProto as an input.
            ExtendedReferenceEvaluator as ReferenceEvaluator,
        )
        from onnx_diagnostic.torch_onnx.sbs import run_aligned


        class Model(torch.nn.Module):
            def forward(self, x):
                ry = x.abs()
                rz = ry.exp()
                rw = rz + 1
                ru = rw.log() + rw
                return ru


        def post_process(obs):
            dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
            dobs["err_abs"] = obs[-1]["abs"]
            dobs["err_rel"] = obs[-1]["rel"]
            return dobs


        x = torch.randn((5, 4))
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        )
        onx = torch.onnx.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        ).model_proto
        results = list(
            map(
                post_process,
                run_aligned(
                    ep,
                    onx,
                    (x,),
                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
                    verbose=1,
                ),
            ),
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
        print(df)
    """
    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs,with_shape=True)}"
    # Unpack the runtime class and the tolerances. Tolerances stay None when
    # only a class is given; discrepancies are then reported but never raise.
    cls, atol, rtol = (
        (
            check_conversion_cls["cls"],
            check_conversion_cls["atol"],
            check_conversion_cls["rtol"],
        )
        if isinstance(check_conversion_cls, dict)
        else (check_conversion_cls, None, None)
    )

    # retrieve the positions: result name -> index of the producing node,
    # under key "fx" for the exported program, "onnx" for the onnx graph
    positions: Dict[str, Any] = {}
    for i, node in enumerate(ep.graph.nodes):
        if isinstance(node.name, str):
            positions[node.name] = dict(fx=i)
        else:
            for n in node.name:
                positions[n] = dict(fx=i)

    for i, node in enumerate(onx.graph.node):
        for n in node.output:
            if n in positions:
                positions[n]["onnx"] = i
            else:
                positions[n] = dict(onnx=i)

    onnx_results: Dict[str, Any] = {}
    for init in onx.graph.initializer:  # type: ignore
        # NOTE(review): this stores an int where every other ``positions``
        # entry is a dict -- presumably initializer names never collide with
        # fx node output names, otherwise ``"onnx" in positions[n]`` below
        # would fail; confirm.
        positions[init.name] = -1
        onnx_results[init.name] = to_array_extended(init)
        # The exported program prefixes parameters with ``p_`` and replaces
        # dots with underscores; register the initializer under that alias too.
        param_name = f"p_{init.name.replace('.', '_')}"
        if param_name == init.name:
            continue
        assert param_name not in onnx_results, (
            f"Some confusion may happen because {init.name!r} -> {param_name!r} "
            f"and onnx_results has {sorted(onnx_results)}"
        )
        onnx_results[param_name] = onnx_results[init.name]

    # Seed the torch side with the onnx initializers, skipping names starting
    # with "init" (assumed to be onnx-only constants).
    torch_results: Dict[str, Any] = {
        k: torch.from_numpy(v.copy())
        for k, v in onnx_results.items()
        if not k.startswith("init")
    }
    last_position = 0  # index of the first onnx node not executed yet
    torch_output_names = None
    for node in ep.graph.nodes:
        if node.op == "output":
            torch_output_names = [n.name for n in node.args[0]]
    onnx_outputs_names = [o.name for o in onx.graph.output]
    assert torch_output_names is not None and len(torch_output_names) == len(
        onnx_outputs_names
    ), (
        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
        f"onnx_outputs_names={onnx_outputs_names}"
    )
    # Final output names may differ between the two graphs: map them by position.
    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

    if verbose:
        for k, v in torch_results.items():
            print(
                f"[run_aligned] +torch-cst: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )
        for k, v in onnx_results.items():
            print(
                f"[run_aligned] +onnx-init: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )

    # Feed the inputs to the onnx side (positional match with graph inputs).
    for inp, v in zip(onx.graph.input, args):
        onnx_results[inp.name] = to_numpy(v)
        if verbose:
            print(
                f"[run_aligned] +onnx-input: {inp.name}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )

    # Main loop: execute the fx graph node by node and, every time a result
    # also produced by the onnx graph becomes available, execute the onnx
    # graph up to that point and compare.
    for i, node in enumerate(ep.graph.nodes):
        if verbose:
            if node.op == "call_function":
                print(
                    f"[run_aligned] run ep.graph.nodes[{i}]: "
                    f"{node.op}[{node.target}] -> {node.name!r}"
                )
            else:
                print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")

        if node.op == "placeholder":
            # Placeholders take their value from the onnx inputs/initializers.
            if node.name in onnx_results:
                torch_results[node.name] = torch.from_numpy(onnx_results[node.name].copy())
                if verbose:
                    t = torch_results[node.name]
                    print(
                        f"[run_aligned] +torch {node.name}="
                        f"{string_type(t, with_shape=True, with_min_max=True)}"
                    )
                continue
            raise AssertionError(
                f"unable to process node {node.op} -> {node.name!r} "
                f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
                f"onx.graph.input={[i.name for i in onx.graph.input]}"
            )

        # NOTE(review): ``args``/``kwargs`` shadow the function parameters
        # from here on; the original inputs are no longer reachable below.
        outputs = [node.name] if isinstance(node.name, str) else list(node.name)
        args, kwargs = prepare_args_kwargs(torch_results, node)
        new_outputs = run_fx_node(node, args, kwargs)
        if isinstance(new_outputs, (torch.Tensor, int, float, list)):
            new_outputs = (new_outputs,)

        if new_outputs is None:
            # Probably an assert.
            continue

        for k, v in zip(outputs, new_outputs):
            torch_results[k] = v
        if verbose:
            for k, v in zip(outputs, new_outputs):
                print(
                    f"[run_aligned] +torch {k}="
                    f"{string_type(v, with_shape=True, with_min_max=True)}"
                )

        # Locate the furthest onnx node producing one of the results just
        # computed; -2 means none of them exists in the onnx graph.
        max_pos = -2
        for n in outputs:
            if n in positions and "onnx" in positions[n]:
                max_pos = max(max_pos, positions[n]["onnx"])
        if max_pos == -2:
            # we skip.
            continue

        # Execute every pending onnx node up to that point to stay aligned.
        for i_onnx in range(last_position, max_pos + 1):
            node = onx.graph.node[i_onnx]
            if verbose:
                print(
                    f"[run_aligned] run onx.graph.node[{i_onnx}]: "
                    f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
                )
            # ``cls`` runs a single NodeProto at a time.
            ref = cls(node)
            feeds = {k: onnx_results[k] for k in node.input}
            res = ref.run(None, feeds)
            for o, r in zip(node.output, res):
                onnx_results[o] = r
                if verbose:
                    print(
                        f"[run_aligned] +onnx {o}="
                        f"{string_type(r, with_shape=True, with_min_max=True)}"
                    )

                to = mapping_onnx_to_torch.get(o, o)
                if to in torch_results:
                    # The same result exists on the torch side: compare.
                    d = max_diff(torch_results[to], r)
                    if verbose:
                        if o == to:
                            print(f"[run_aligned] =common results {to}: {string_diff(d)}")
                        else:
                            print(f"[run_aligned] =common results {to}/{o}: {string_diff(d)}")
                    if not (
                        atol is None
                        or rtol is None
                        or (d["abs"] <= atol and d["rel"] <= rtol)
                    ):
                        skw = dict(with_shape=True, with_min_max=True)
                        raise ValueError(
                            f"discrepancies detected for results [{to}/{o}]: "
                            f"{string_diff(d)}"
                            f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
                            f"\n-- onnx_results: {string_type(r, **skw)}"
                            f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                        )
                    yield (i, i_onnx, o, to, d)

        last_position = max_pos + 1

    # complete the execution of the onnx graph
    for i_onnx in range(last_position, len(onx.graph.node)):
        node = onx.graph.node[i_onnx]
        if verbose:
            print(
                f"[run_aligned] run onx.graph.node[{i_onnx}]: "
                f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
            )
        ref = cls(node)
        feeds = {k: onnx_results[k] for k in node.input}
        res = ref.run(None, feeds)
        for o, r in zip(node.output, res):
            onnx_results[o] = r
            if verbose:
                print(
                    f"[run_aligned] +onnx {o}="
                    f"{string_type(r, with_shape=True, with_min_max=True)}"
                )

            to = mapping_onnx_to_torch.get(o, o)
            if to in torch_results:
                d = max_diff(torch_results[to], r)
                if verbose:
                    if o == to:
                        print(f"[run_aligned] =common results* {to}: {string_diff(d)}")
                    else:
                        print(f"[run_aligned] =common results* {to}/{o}: {string_diff(d)}")
                if not (
                    atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
                ):
                    skw = dict(with_shape=True, with_min_max=True)
                    raise ValueError(
                        f"discrepancies detected for results* [{to}/{o}]: {string_diff(d)}"
                        f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
                        f"\n-- onnx_results: {string_type(r, **skw)}"
                        f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                    )
                # NOTE(review): ``i`` is whatever index the fx loop above left
                # behind -- presumably the last fx node; confirm this is intended.
                yield (i, i_onnx, o, to, d)
@@ -0,0 +1,213 @@
1
+ Metadata-Version: 2.4
2
+ Name: onnx-diagnostic
3
+ Version: 0.8.0
4
+ Summary: Tools to help converting pytorch models into ONNX.
5
+ Home-page: https://github.com/sdpython/onnx-diagnostic
6
+ Author: Xavier Dupré
7
+ Author-email: Xavier Dupré <xavier.dupre@gmail.com>
8
+ License: MIT
9
+ Project-URL: Homepage, https://sdpython.github.io/doc/onnx-diagnostic/dev/
10
+ Project-URL: Repository, https://github.com/sdpython/onnx-diagnostic/
11
+ Requires-Python: >=3.9
12
+ Description-Content-Type: text/x-rst
13
+ License-File: LICENSE.txt
14
+ Dynamic: author
15
+ Dynamic: home-page
16
+ Dynamic: license-file
17
+
18
+
19
+ .. image:: https://github.com/sdpython/onnx-diagnostic/raw/main/_doc/_static/logo.png
20
+ :width: 120
21
+
22
+ onnx-diagnostic: investigate onnx models
23
+ ========================================
24
+
25
+ .. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg
26
+ :target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml
27
+
28
+ .. image:: https://badge.fury.io/py/onnx-diagnostic.svg
29
+ :target: http://badge.fury.io/py/onnx-diagnostic
30
+
31
+ .. image:: https://img.shields.io/badge/license-MIT-blue.svg
32
+ :alt: MIT License
33
+ :target: https://opensource.org/license/MIT/
34
+
35
+ .. image:: https://img.shields.io/github/repo-size/sdpython/onnx-diagnostic
36
+ :target: https://github.com/sdpython/onnx-diagnostic/
37
+ :alt: size
38
+
39
+ .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
40
+ :target: https://github.com/psf/black
41
+
42
+ .. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/graph/badge.svg?token=91T5ZVIP96
43
+ :target: https://codecov.io/gh/sdpython/onnx-diagnostic
44
+
45
+ The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
46
+ it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
47
+ Patches can be enabled as follows:
48
+
49
+ .. code-block:: python
50
+
51
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
52
+
53
+ with torch_export_patches(patch_transformers=True) as f:
54
+ ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
55
+ # ...
56
+
57
+ Dynamic shapes are difficult to guess for caches, one function
58
+ returns a structure defining all dimensions as dynamic.
59
+ You need then to remove those which are not dynamic in your model.
60
+
61
+ .. code-block:: python
62
+
63
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
64
+
65
+ dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
66
+
67
+ It also implements tools to investigate, validate exported models (ExportedProgram, ONNXProgram, ...).
68
+ See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
69
+ `torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.
70
+
71
+ Getting started
72
+ +++++++++++++++
73
+
74
+ ::
75
+
76
+ git clone https://github.com/sdpython/onnx-diagnostic.git
77
+ cd onnx-diagnostic
78
+ pip install -e . -v
79
+
80
+ or
81
+
82
+ ::
83
+
84
+ pip install onnx-diagnostic
85
+
86
+ Enlightening Examples
87
+ +++++++++++++++++++++
88
+
89
+ **Where to start to export a model**
90
+
91
+ * `Export microsoft/phi-2
92
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
93
+
94
+ **Torch Export**
95
+
96
+ * `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
97
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
98
+ * `Find and fix an export issue due to dynamic shapes
99
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_locate_issue.html>`_
100
+ * `Export with DynamicCache and guessed dynamic shapes
101
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
102
+ * `Steal forward method to guess the dynamic shapes (with Tiny-LLM)
103
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
104
+ * `Export Tiny-LLM with patches
105
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_patched.html>`_
106
+
107
+ **Investigate ONNX models**
108
+
109
+ * `Find where a model is failing by running submodels
110
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_failing_model_extract.html>`_
111
+ * `Intermediate results with (ONNX) ReferenceEvaluator
112
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_failing_reference_evaluator.html>`_
113
+ * `Intermediate results with onnxruntime
114
+ <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_failing_onnxruntime_evaluator.html>`_
115
+
116
+ Snapshot of useful tools
117
+ +++++++++++++++++++++++++
118
+
119
+ **torch_export_patches**
120
+
121
+ .. code-block:: python
122
+
123
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
124
+
125
+ with torch_export_patches(patch_transformers=True) as f:
126
+ ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
127
+ # ...
128
+
129
+ **all_dynamic_shapes_from_inputs**
130
+
131
+ .. code-block:: python
132
+
133
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
134
+
135
+ dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
136
+
137
+ **torch_export_rewrite**
138
+
139
+ .. code-block:: python
140
+
141
+ from onnx_diagnostic.torch_export_patches import torch_export_rewrite
142
+
143
+ with torch_export_rewrite(rewrite=[Model.forward]) as f:
144
+ ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
145
+ # ...
146
+
147
+ **string_type**
148
+
149
+ .. code-block:: python
150
+
151
+ import torch
152
+ from onnx_diagnostic.helpers import string_type
153
+
154
+ inputs = (
155
+ torch.rand((3, 4), dtype=torch.float16),
156
+ [torch.rand((5, 6), dtype=torch.float16), torch.rand((5, 6, 7), dtype=torch.float16)],
157
+ )
158
+
159
+ # with shapes
160
+ print(string_type(inputs, with_shape=True))
161
+
162
+ ::
163
+
164
+ >>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
165
+
166
+ **onnx_dtype_name**
167
+
168
+ .. code-block:: python
169
+
170
+ import onnx
171
+ from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name
172
+
173
+ itype = onnx.TensorProto.BFLOAT16
174
+ print(onnx_dtype_name(itype))
175
+ print(onnx_dtype_name(7))
176
+
177
+ ::
178
+
179
+ >>> BFLOAT16
180
+ >>> INT64
181
+
182
+ **max_diff**
183
+
184
+ .. code-block:: python
185
+
186
+ import torch
187
+ from onnx_diagnostic.helpers import max_diff
188
+
189
+ print(
190
+ max_diff(
191
+ (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
192
+ (torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
193
+ )
194
+ )
195
+
196
+ ::
197
+
198
+ >>> {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0}
199
+
200
+ **guess_dynamic_shapes**
201
+
202
+ .. code-block:: python
203
+
204
+ inputs = [
205
+ (torch.randn((5, 6)), torch.randn((1, 6))),
206
+ (torch.randn((7, 8)), torch.randn((1, 8))),
207
+ ]
208
+ ds = ModelInputs(model, inputs).guess_dynamic_shapes(auto="dim")
209
+ print(ds)
210
+
211
+ ::
212
+
213
+ >>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {})