onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/control_flow.py (new file)
@@ -0,0 +1,214 @@
+ import contextlib
+ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+ import torch
+ from torch._higher_order_ops.utils import (
+     materialize_as_graph,
+     check_input_alias_and_mutation_return_outputs,
+     # _maybe_reenter_make_fx,
+ )
+
+ _TEST_EXPORT = False
+
+
+ @contextlib.contextmanager
+ def enable_code_export_control_flow():
+     """Enables the code meant to be exported."""
+     global _TEST_EXPORT
+     old = _TEST_EXPORT
+     _TEST_EXPORT = True
+     try:
+         yield
+     finally:
+         _TEST_EXPORT = old
+
+
+ def is_exporting() -> bool:
+     """
+     Returns :func:`torch.compiler.is_exporting` or
+     :func:`torch.compiler.is_compiling`.
+     Set ``_TEST_EXPORT`` to force it to trigger.
+     """
+     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
+
+
+ def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
+     """
+     Python implementation of the loop.
+
+     :param n_iter: number of iterations
+     :param body_fn: function implementing the body
+     :param reduction_dim: dimensions used to reduce the lists produced by the loop
+     :param args: arguments to the loop body
+     :return: results
+     """
+     res = []
+     for i in torch.arange(n_iter, dtype=n_iter.dtype):
+         r = body_fn(i, *args)
+         if isinstance(r, tuple):
+             assert not res or len(r) == len(res[-1]), (
+                 f"Unexpected number of results {len(r)} for function {body_fn}, "
+                 f"expected {len(res[-1])}"
+             )
+             res.append(r)
+         else:
+             assert isinstance(r, torch.Tensor), (
+                 f"Unexpected type {type(r)} for function {body_fn}, "
+                 f"it must be a tuple or a Tensor."
+             )
+             assert not res or len(res[-1]) == 1, (
+                 f"Unexpected single result for function {body_fn}, "
+                 f"expected {len(res[-1])} results"
+             )
+             res.append((r,))
+
+     if not res:
+         return torch.empty(tuple(), dtype=torch.float32, device=args[0].device)
+     if len(res) == 1:
+         final = res[0]
+     else:
+         n_res = len(res[0])
+         final = [
+             torch.cat(
+                 [r[i] for r in res],
+                 dim=(
+                     0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
+                 )
+             )
+             for i in range(n_res)
+         ]
+     return tuple(final) if len(final) > 1 else final[0]
+
+
+ def make_custom_loop_for(
+     n_iter: torch.Tensor,
+     body_fn: Callable,
+     reduction_dim: Optional[Sequence[int]],
+     args: Sequence[torch.Tensor],
+     body_gm: Optional[torch.fx.GraphModule] = None,
+     body_mutated_inputs: Optional[List[Any]] = None,
+     body_outputs: Optional[List[Any]] = None,
+ ) -> Tuple[str, torch.library.CustomOpDef]:
+     """
+     Defines a custom operator for a loop in order to avoid
+     :func:`torch.export.export` digging into it.
+     It registers the custom op and a custom conversion
+     to ONNX.
+
+     :param n_iter: number of iterations defined by a tensor with no dimension
+     :param body_fn: the loop body defined as a function
+     :param reduction_dim: dimensions used to concatenate the results
+     :param args: list of tensors, input to the body
+     :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
+     :param body_mutated_inputs: mutated inputs of *body_gm*
+     :param body_outputs: outputs of *body_gm*
+     :return: a name and the custom op definition, the name
+         is used to cache the custom op
+     """
+     assert body_gm is not None, "body_gm cannot be None"
+     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
+     assert body_outputs is not None, "body_outputs cannot be None"
+
+     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
+     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
+     full_name = (
+         body_fn.__qualname__.replace("<locals>", "L")
+         .replace("<lambda>", "l")
+         .replace(".", "_")
+     )
+     name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
+
+     schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor"
+     if len(body_outputs) > 1:
+         schema += "[]"
+     custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn)
+     custom_def.register_kernel("cpu")(body_fn)
+
+     custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: (
+         tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
+     )
+     return name, custom_def
+
+
+ def loop_for(
+     n_iter: Union[torch.SymInt, torch.Tensor],
+     body_fn: Callable[..., Tuple[torch.Tensor]],
+     args: Sequence[torch.Tensor],
+     reduction_dim: Optional[Sequence[int]] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+     """
+     Higher-order operator used to easily export a loop to ONNX.
+     It does not fully work with :func:`torch.export.export`:
+     it inserts a custom op which is replaced by a loop operator afterwards.
+     Every iteration produces tensors; they are gathered
+     into lists, and these lists are concatenated into tensors.
+
+     :param n_iter: number of iterations, it can be fixed or
+         variable, in both cases it must be a tensor with no dimension
+     :param body_fn: function body, takes only tensors and returns
+         only tensors, the first argument is the iteration number
+         in a tensor with no dimension, all the others
+         are left unchanged during the loop
+     :param args: the tensors available at every iteration
+     :param reduction_dim: the loop aggregates the results into lists,
+         one for each output, each of them is concatenated into one
+         tensor along one dimension, by default the first
+         dimension, unless defined otherwise here
+     """
+     assert args, "The function should have at least one arg."
+     assert (
+         isinstance(n_iter, torch.Tensor)
+         and n_iter.dtype == torch.int64
+         and len(n_iter.shape) == 0
+     ), f"Only a 0-dim int64 tensor is allowed for n_iter but it is equal to {n_iter}."
+     if is_exporting():
+         from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
+
+         # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer
+         root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root
+         # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph
+
+         body_gm: torch.fx.GraphModule = materialize_as_graph(
+             body_fn, (torch.tensor(0, dtype=torch.int64), *args)
+         )
+         (
+             _1,
+             _2,
+             _3,
+             body_mutated_inputs,
+             body_outputs,
+         ) = check_input_alias_and_mutation_return_outputs(body_gm)
+         name, _custom_ops = make_custom_loop_for(
+             n_iter,
+             body_fn,
+             reduction_dim,
+             args,
+             body_gm=body_gm,
+             body_mutated_inputs=body_mutated_inputs,
+             body_outputs=body_outputs,
+         )
+         root.register_module(name, body_gm)
+         # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args)
+         return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args)
+
+     return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
+
+
+ """
+ proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
+ proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
+
+ args = (cond_graph, body_graph, carried_inputs, additional_inputs)
+
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
+
+ out_proxy = proxy_mode.tracer.create_proxy(
+     "call_function", op, proxy_args, {}, name=op._name
+ )
+
+ out = op(
+     cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
+ )
+ return track_tensor_tree(
+     out, out_proxy, constant=None, tracer=proxy_mode.tracer
+ )
+ """