onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,108 @@
1
+ from typing import Callable, Set
2
+ import torch
3
+ from ..helpers.torch_helper import is_torchdynamo_exporting
4
+
5
+
6
def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Uses for a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )

    :param i: desired dimension, an int or a symbolic value
    :return: an integer equal to ``i`` when ``i`` is concrete, produced
        through :func:`torch.nonzero` so that the exporter sees it as a
        data-dependent (undefined) dimension
    """
    try:
        ti = int(i)
    except Exception:
        # Fix: was a bare ``except:`` which also swallowed KeyboardInterrupt
        # and SystemExit. ``i`` may be symbolic (SymInt) and fail int();
        # the fallback size is arbitrary, only the data dependency matters.
        ti = 10
    # Build a tensor with exactly ``ti`` nonzero elements; ``torch.nonzero``
    # makes the resulting dimension data-dependent for the exporter.
    t = torch.ones((ti * 2,))
    t[:ti] = 0
    res = torch.nonzero(t).shape[0]
    return res
28
+
29
+
30
def _patched_float_arange(
    start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
) -> torch.Tensor:
    """Float arange: calls :func:`torch.arange` on the scalar values of
    three 0-d tensors, keeping the dtype and device of ``start``."""
    bounds = [float(t.item()) for t in (start, end, step)]
    return torch.arange(*bounds, dtype=start.dtype, device=start.device)
41
+
42
+
43
def _patched_float_arange_shape(start, end, step):
    """Fake (shape) implementation for :func:`_patched_float_arange`."""
    # Computing the real length with ``((end - start) / step).item()`` fails
    # during export: the fake-tensor mode complains that ``item()`` is
    # called more often than in the real implementation.  The concrete
    # value is irrelevant for the fake output, so a constant is used and
    # wrapped into an undefined dimension.
    length = 10
    dim = make_undefined_dimension(length)
    return torch.empty((dim,), dtype=start.dtype, device=start.device)
53
+
54
+
55
def _iterate_patched_expressions():
    """Yields ``(name, fct, fct_shape)`` for every patched expression in
    this module, i.e. every global ``_patched_<name>`` (excluding the
    ``*_shape`` fake implementations) paired with ``_patched_<name>_shape``."""
    prefix = "_patched_"
    snapshot = globals().copy()
    for full_name in snapshot:
        if full_name.startswith(prefix) and not full_name.endswith("_shape"):
            short = full_name[len(prefix):]
            yield short, snapshot[full_name], snapshot[f"{full_name}_shape"]
61
+
62
+
63
# Names of patched expressions already registered as custom ops; used by
# register_patched_expressions to avoid registering the same op twice.
_registered: Set[str] = set()
64
+
65
+
66
+ def _register_patched_expression(
67
+ fct: Callable, fct_shape: Callable, namespace: str, fname: str
68
+ ):
69
+ schema_str = torch.library.infer_schema(fct, mutates_args=())
70
+ custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
71
+ custom_def.register_kernel("cpu")(fct)
72
+ custom_def._abstract_fn = fct_shape
73
+
74
+
75
def register_patched_expressions(namespace: str = "patched"):
    """
    Registers as custom ops known expressions failing due to dynamic shapes.

    .. runpython::
        :showcode:

        import pprint
        from onnx_diagnostic.torch_export_patches.patch_expressions import (
            _iterate_patched_expressions,
        )

        pprint.pprint([name for name, _f, _fsh in _iterate_patched_expressions()])
    """
    for name, fct, fct_shape in _iterate_patched_expressions():
        if name in _registered:
            # Custom ops cannot be registered twice.
            continue
        _register_patched_expression(fct, fct_shape, namespace, name)
        _registered.add(name)
93
+
94
+
95
def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
    """
    Returns **fct** if the model is being executed or
    **patched_fct** if it is being exported.
    """
    if is_torchdynamo_exporting():
        return patched_fct
    return fct
101
+
102
+
103
def patched_float_arange(start, end, step):
    """Patched arange when start, end, step are floats."""
    if not is_torchdynamo_exporting():
        return torch.arange(start, end, step)
    # While exporting, use the custom op registered by
    # register_patched_expressions instead of torch.arange.
    return torch.ops.patched.float_arange(start, end, step)
@@ -0,0 +1,211 @@
1
+ import inspect
2
+ from typing import Any, Dict, Optional, Tuple
3
+ import torch
4
+ import transformers
5
+ from ..helpers import string_type
6
+
7
+
8
def _process_cache(k: str, v):
    """
    Normalizes one input value for :func:`torch.export.export`.

    :param k: parameter name, used for error messages and sanity checks
    :param v: a :class:`torch.Tensor`, or a list of 2-tuples of tensors
        representing the layers of a DynamicCache
    :return: the tensor unchanged, or a ``DynamicCache`` built from the pairs
    :raises NotImplementedError: for any other type
    """
    # Fix: the original tested ``isinstance(k, torch.Tensor)`` — the *name*,
    # always a str — so any ``position_ids`` input failed this assert.
    assert k != "position_ids" or isinstance(
        v, torch.Tensor
    ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
    if (
        isinstance(v, list)
        and all(isinstance(i, tuple) for i in v)
        and set(len(t) for t in v) == {2}
    ):
        # A list of (key, value) pairs: rebuild a DynamicCache.
        from ..helpers.cache_helper import make_dynamic_cache

        cache = make_dynamic_cache(v)
        return cache
    if isinstance(v, torch.Tensor):
        return v
    raise NotImplementedError(
        f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}"
    )
27
+
28
+
29
def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
    """Builds the dynamic shapes for ``value`` (an instance of ``cls``)
    from ``subset``, the dynamic axes sharing the value's prefix."""
    if cls is transformers.cache_utils.DynamicCache:
        assert subset, "DynamicCache cannot be empty"
        values = set(map(str, subset.values()))
        assert len(values) == 1, (
            f"Inconsistencies in subset={subset}, found={values}, "
            f"it cannot be a {cls}, value={string_type(value)}"
        )
        # Number of layers: recent DynamicCache exposes ``layers``,
        # older versions ``key_cache``.
        layers = value.layers if hasattr(value, "layers") else value.key_cache
        cache_length = len(layers)
        # All key/value tensors of the cache share the same axes.
        axes = next(iter(subset.values()))
        return [axes] * (cache_length * 2)
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        raise NotImplementedError(
            f"_make_shape not implemented for registered class={cls}, "
            f"subset={subset}, value={string_type(value)}"
        )
    raise NotImplementedError(
        f"_make_shape not implemented for cls={cls}, "
        f"subset={subset}, value={string_type(value)}"
    )
52
+
53
+
54
def convert_dynamic_axes_into_dynamic_shapes(
    model: torch.nn.Module,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
    prefix_mapping: Optional[Dict[str, str]] = None,
    verbose: int = 0,
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
    """
    Converts the input from an export to something :func:`torch.export.export` can handle.

    :param model: model to convert (used to extract the signature)
    :param args: positional arguments
    :param kwargs: named arguments
    :param dynamic_axes: dynamic axes
    :param prefix_mapping: prefix mapping
    :param verbose: verbosity
    :return: (args, kwargs, dynamic shapes)
    """
    new_kwargs = {}
    if args:
        # Map positional arguments onto parameter names from the signature.
        assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
        plus = 0 if isinstance(model, torch.nn.Module) else 1
        if verbose:
            # Fix: this message was printed unconditionally; it is now
            # guarded by ``verbose`` like every other message here.
            print(
                f"[convert_dynamic_axes_into_dynamic_shapes] "
                f"mapping args to kwargs for model="
                f"{model if plus else model.__class__.__name__}"
            )
        pars = inspect.signature(model.forward).parameters
        assert len(pars) >= len(
            args
        ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"

        for i, p in enumerate(pars):
            if i < plus:
                continue
            if i - plus >= len(args):
                break
            if verbose:
                print(
                    f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] "
                    f"to {p!r} ({string_type(args[i-plus])})"
                )
            new_kwargs[p] = args[i - plus]

    if kwargs:
        for k, v in kwargs.items():
            assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
            new_kwargs[k] = v

    # process: tensors are kept, lists of pairs become caches,
    # DynamicCache instances are flattened into [key_cache, value_cache].
    updated_kwargs = {}
    changes = {}  # parameters whose type was changed: name -> new (or old) type
    for k, v in new_kwargs.items():
        if isinstance(v, torch.Tensor):
            updated_kwargs[k] = v
            continue
        if isinstance(v, list):
            # cache?
            updated_kwargs[k] = _process_cache(k, v)
            if type(updated_kwargs[k]) is not type(v):
                # A cache was introduced.
                if verbose:
                    print(
                        f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
                        f"{k!r} was changed into {type(updated_kwargs[k])}"
                    )
                changes[k] = type(updated_kwargs[k])
            continue
        if isinstance(v, transformers.cache_utils.DynamicCache):
            # Fix: import deferred to the only branch using it instead of
            # running on every call.
            from ..helpers.cache_helper import CacheKeyValue

            ca = CacheKeyValue(v)
            updated_kwargs[k] = [ca.key_cache, ca.value_cache]
            changes[k] = type(v)
            continue
        raise NotImplementedError(
            f"Unexpected type {type(v)} for parameter {k!r} "
            f"({string_type(v, with_shape=True)})"
        )

    # process dynamic axes
    # Fix: ``dynamic_shapes`` was only assigned inside ``if changes:``,
    # raising UnboundLocalError at the return below whenever no cache
    # conversion happened; it is now always defined and the axes are
    # processed whenever they are provided.
    dynamic_shapes: Dict[str, Any] = {}
    if dynamic_axes:
        done = set()
        for k, v in dynamic_axes.items():
            if k not in changes and k in updated_kwargs and isinstance(v, dict):
                # Plain tensor axes: kept as they are.
                dynamic_shapes[k] = v
                continue
            if (
                k in updated_kwargs
                and k in changes
                and changes[k] == transformers.cache_utils.DynamicCache
            ):
                dynamic_shapes[k] = v
                continue
            if "." in k:
                # something like present.0.key
                prefix = k.split(".")[0]
                if prefix in done:
                    continue
                args_prefix = (
                    prefix_mapping[prefix]
                    if prefix_mapping and prefix in prefix_mapping
                    else prefix
                )
                if args_prefix in updated_kwargs and args_prefix in changes:
                    # A cache: gather every axis sharing this prefix.
                    cls = changes[args_prefix]
                    dynamic_shapes[args_prefix] = _make_shape(
                        {
                            _: __
                            for _, __ in dynamic_axes.items()
                            if _.startswith(f"{prefix}.")
                        },
                        cls,
                        updated_kwargs[args_prefix],
                    )
                    done.add(prefix)
                    continue
            if k not in updated_kwargs:
                # dynamic axes not in the given inputs, should this raise an exception?
                if verbose:
                    print(
                        f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
                        f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
                    )
                continue
            raise NotImplementedError(
                f"Unable to process dynamic axes {k!r}, axes={v}, "
                f"value={string_type(updated_kwargs[k], with_shape=True)}, "
                f"dynamic axes={dynamic_axes}, "
                f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
            )

    return (), updated_kwargs, dynamic_shapes
190
+
191
+
192
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
    """
    Some functions returns dynamic shapes as string.
    This functions replaces them with ``torch.export.Dim.DYNAMIC``.
    ``default_value=torch.export.Dim.AUTO`` changes the default value.
    """
    if isinstance(dynamic_shapes, str):
        # A string names a dynamic dimension: replace it with the hint.
        return torch.export.Dim.DYNAMIC if default_value is None else default_value

    def _rec(item):
        return use_dyn_not_str(item, default_value=default_value)

    if isinstance(dynamic_shapes, dict):
        return {key: _rec(val) for key, val in dynamic_shapes.items()}
    if isinstance(dynamic_shapes, list):
        return [_rec(item) for item in dynamic_shapes]
    if isinstance(dynamic_shapes, tuple):
        return tuple(_rec(item) for item in dynamic_shapes)
    if isinstance(dynamic_shapes, set):
        return {_rec(item) for item in dynamic_shapes}
    # Anything else (ints, Dim objects, None, ...) is returned unchanged.
    return dynamic_shapes