onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,311 @@
1
+ import pprint
2
+ from typing import Any, Callable, Dict, Optional, Set
3
+ import packaging.version as pv
4
+ import optree
5
+ import torch
6
+ import transformers
7
+ from transformers.cache_utils import (
8
+ DynamicCache,
9
+ EncoderDecoderCache,
10
+ HybridCache,
11
+ SlidingWindowCache,
12
+ StaticCache,
13
+ )
14
+
15
+ from ..helpers import string_type
16
+ from .serialization import _lower_name_with_
17
+
18
+ PATCH_OF_PATCHES: Set[Any] = set()
19
+
20
+
21
def get_mamba_cache_cls() -> Optional[type]:
    """
    Returns the ``MambaCache`` class wherever :epkg:`transformers` defines it.

    The class moved between transformers versions:
    ``transformers.models.mamba.modeling_mamba`` (recent) or
    ``transformers.cache_utils`` (older).

    :return: the ``MambaCache`` class, or ``None`` if it cannot be imported
    """
    try:
        from transformers.models.mamba.modeling_mamba import MambaCache

        return MambaCache
    except ImportError:
        # older transformers exposed it in cache_utils
        try:
            from transformers.cache_utils import MambaCache

            return MambaCache
        except ImportError:
            # MambaCache is not available at all in this environment
            return None
33
+
34
+
35
def register_class_serialization(
    cls,
    f_flatten: Callable,
    f_unflatten: Callable,
    f_flatten_with_keys: Callable,
    f_check: Optional[Callable] = None,
    verbose: int = 0,
) -> bool:
    """
    Registers a class.
    It can be undone with
    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.

    :param cls: class to register, ``None`` is tolerated and ignored
    :param f_flatten: see ``torch.utils._pytree.register_pytree_node``
    :param f_unflatten: see ``torch.utils._pytree.register_pytree_node``
    :param f_flatten_with_keys: see ``torch.utils._pytree.register_pytree_node``
    :param f_check: called to check the registration was successful
    :param verbose: verbosity
    :return: registered or not
    """
    # cls may be None (e.g. MambaCache missing in this transformers version),
    # in which case there is nothing to register.
    if cls is None:
        return False
    if cls in torch.utils._pytree.SUPPORTED_NODES:
        if verbose:
            print(f"[register_class_serialization] already registered {cls.__name__}")
        return False

    if verbose:
        print(f"[register_class_serialization] ---------- register {cls.__name__}")
    torch.utils._pytree.register_pytree_node(
        cls,
        f_flatten,
        f_unflatten,
        serialized_type_name=f"{cls.__module__}.{cls.__name__}",
        flatten_with_keys_fn=f_flatten_with_keys,
    )
    if pv.Version(torch.__version__) < pv.Version("2.7"):
        # older torch keeps a separate registry for fx flatten specs
        if verbose:
            print(
                f"[register_class_serialization] "
                f"---------- register {cls.__name__} for torch=={torch.__version__}"
            )
        torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])

    # check: flatten then unflatten an instance and verify it round-trips
    if f_check:
        inst = f_check()
        values, spec = torch.utils._pytree.tree_flatten(inst)
        restored = torch.utils._pytree.tree_unflatten(values, spec)
        assert string_type(inst, with_shape=True) == string_type(restored, with_shape=True), (
            f"Issue with registration of class {cls} "
            f"inst={string_type(inst, with_shape=True)}, "
            f"restored={string_type(restored, with_shape=True)}"
        )
    return True
89
+
90
+
91
def register_cache_serialization(
    patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
) -> Dict[type, bool]:
    """
    Registers many classes with
    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
    Returns information needed to undo the registration.

    :param patch_transformers: add serialization function for
        :epkg:`transformers` package
    :param patch_diffusers: add serialization function for
        :epkg:`diffusers` package
    :param verbose: verbosity level
    :return: information to unpatch (class -> whether it was registered here)
    """
    # classes whose existing (upstream) registration is known to be wrong,
    # mapped to the minimal transformers version for which that applies
    wrong: Dict[type, Optional[str]] = {}
    if patch_transformers:
        from .serialization.transformers_impl import WRONG_REGISTRATIONS

        wrong |= WRONG_REGISTRATIONS
    if patch_diffusers:
        from .serialization.diffusers_impl import WRONG_REGISTRATIONS

        wrong |= WRONG_REGISTRATIONS

    registration_functions = serialization_functions(
        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
    )

    # DynamicCache serialization is different in transformers and does not
    # play well with torch.export.export.
    # see test test_export_dynamic_cache_cat with NOBYPASS=1
    # :: NOBYPASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c
    # This is caused by this line:
    # torch.fx._pytree.register_pytree_flatten_spec(
    #   DynamicCache, _flatten_dynamic_cache_for_fx)
    # so we remove it anyway
    # BaseModelOutput serialization is incomplete.
    # It does not include dynamic shapes mapping.
    for cls, version in wrong.items():
        if (
            cls in torch.utils._pytree.SUPPORTED_NODES
            and cls not in PATCH_OF_PATCHES
            # and pv.Version(torch.__version__) < pv.Version("2.7")
            and (
                version is None or pv.Version(transformers.__version__) >= pv.Version(version)
            )
        ):
            assert cls in registration_functions, (
                f"{cls} has no registration functions mapped to it, "
                f"available options are {list(registration_functions)}"
            )
            if verbose:
                print(
                    f"[_fix_registration] {cls.__name__} is unregistered and "
                    f"registered first"
                )
            # replace the upstream registration by this package's one
            unregister_class_serialization(cls, verbose=verbose)
            registration_functions[cls](verbose=verbose)  # type: ignore[arg-type, call-arg]
            if verbose:
                print(f"[_fix_registration] {cls.__name__} done.")
            # To avoid doing it multiple times.
            PATCH_OF_PATCHES.add(cls)

    # classes with no registration at all.
    done = {}
    for k, v in registration_functions.items():
        done[k] = v(verbose=verbose)  # type: ignore[arg-type, call-arg]
    return done
160
+
161
+
162
def serialization_functions(
    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
) -> Dict[type, Callable[[int], bool]]:
    """
    Returns the list of serialization functions.

    :param patch_transformers: include cache classes from :epkg:`transformers`
    :param patch_diffusers: include dataclasses from :epkg:`diffusers`
    :param verbose: verbosity, bound as the default ``verbose`` argument of
        every returned callable
    :return: mapping from class to a callable that registers its serialization
    """

    supported_classes: Set[type] = set()
    classes: Dict[type, Callable[[int], bool]] = {}
    # holds the *_impl module __dict__ so flatten/unflatten functions
    # can be looked up by name (keys are strings)
    all_functions: Dict[str, Any] = {}

    if patch_transformers:
        from .serialization.transformers_impl import (
            __dict__ as dtr,
            SUPPORTED_DATACLASSES,
            flatten_dynamic_cache,
            unflatten_dynamic_cache,
            flatten_with_keys_dynamic_cache,
            flatten_hybrid_cache,
            unflatten_hybrid_cache,
            flatten_with_keys_hybrid_cache,
            flatten_mamba_cache,
            unflatten_mamba_cache,
            flatten_with_keys_mamba_cache,
            flatten_encoder_decoder_cache,
            unflatten_encoder_decoder_cache,
            flatten_with_keys_encoder_decoder_cache,
            flatten_sliding_window_cache,
            unflatten_sliding_window_cache,
            flatten_with_keys_sliding_window_cache,
            flatten_static_cache,
            unflatten_static_cache,
            flatten_with_keys_static_cache,
        )

        all_functions.update(dtr)
        supported_classes |= SUPPORTED_DATACLASSES

        # ``verbose=verbose`` binds the current value as a default argument so
        # each closure is self-contained once this frame returns.
        transformers_classes = {
            DynamicCache: lambda verbose=verbose: register_class_serialization(
                DynamicCache,
                flatten_dynamic_cache,
                unflatten_dynamic_cache,
                flatten_with_keys_dynamic_cache,
                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
                verbose=verbose,
            ),
            HybridCache: lambda verbose=verbose: register_class_serialization(
                HybridCache,
                flatten_hybrid_cache,
                unflatten_hybrid_cache,
                flatten_with_keys_hybrid_cache,
                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
                verbose=verbose,
            ),
            EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
                EncoderDecoderCache,
                flatten_encoder_decoder_cache,
                unflatten_encoder_decoder_cache,
                flatten_with_keys_encoder_decoder_cache,
                verbose=verbose,
            ),
            SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
                SlidingWindowCache,
                flatten_sliding_window_cache,
                unflatten_sliding_window_cache,
                flatten_with_keys_sliding_window_cache,
                verbose=verbose,
            ),
            StaticCache: lambda verbose=verbose: register_class_serialization(
                StaticCache,
                flatten_static_cache,
                unflatten_static_cache,
                flatten_with_keys_static_cache,
                verbose=verbose,
            ),
        }
        # MambaCache is optional, it depends on the transformers version
        MambaCache = get_mamba_cache_cls()
        if MambaCache:
            transformers_classes[MambaCache] = (
                lambda verbose=verbose: register_class_serialization(
                    MambaCache,
                    flatten_mamba_cache,
                    unflatten_mamba_cache,
                    flatten_with_keys_mamba_cache,
                    verbose=verbose,
                )
            )
        classes.update(transformers_classes)

    if patch_diffusers:
        from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu

        all_functions.update(dfu)
        supported_classes |= SUPPORTED_DATACLASSES

    # Dataclasses follow a naming convention: their flatten/unflatten
    # functions are named after the lowered class name.
    for cls in supported_classes:
        lname = _lower_name_with_(cls.__name__)
        assert (
            f"flatten_{lname}" in all_functions
        ), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}"
        # _ln, cls, _al are bound as defaults so every lambda keeps its own
        # values instead of sharing the loop variables (late-binding pitfall)
        classes[cls] = (
            lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization(  # noqa: E501
                cls,
                _al[f"flatten_{_ln}"],
                _al[f"unflatten_{_ln}"],
                _al[f"flatten_with_keys_{_ln}"],
                verbose=verbose,
            )
        )
    return classes
271
+
272
+
273
def unregister_class_serialization(cls: type, verbose: int = 0):
    """
    Undo the registration.

    Removes ``cls`` from both ``torch.fx._pytree`` and ``torch.utils._pytree``
    registries, whichever of them apply to the installed torch version.

    :param cls: class to unregister
    :param verbose: verbosity
    """
    # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
    if cls in torch.fx._pytree.SUPPORTED_NODES:
        del torch.fx._pytree.SUPPORTED_NODES[cls]
    if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH:
        del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls]
    if hasattr(torch.utils._pytree, "_deregister_pytree_node"):
        # torch >= 2.7
        torch.utils._pytree._deregister_pytree_node(cls)
    else:
        # older torch: remove manually, including the optree mirror
        if cls in torch.utils._pytree.SUPPORTED_NODES:
            del torch.utils._pytree.SUPPORTED_NODES[cls]
        optree.unregister_pytree_node(cls, namespace="torch")
        if cls in torch.utils._pytree.SUPPORTED_NODES:
            # pv is already imported at module level
            if pv.Version(torch.__version__) < pv.Version("2.7.0"):
                del torch.utils._pytree.SUPPORTED_NODES[cls]
    assert cls not in torch.utils._pytree.SUPPORTED_NODES, (
        f"{cls} was not successful unregistered "
        f"from torch.utils._pytree.SUPPORTED_NODES="
        f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}"
    )
    if verbose:
        print(f"[unregister_cache_serialization] unregistered {cls.__name__}")
299
+
300
+
301
def unregister_cache_serialization(undo: Dict[Any, bool], verbose: int = 0):
    """
    Undo all registrations.

    :param undo: dictionary returned by
        :func:`register_cache_serialization`; its keys are the registered
        classes (older callers may pass class names instead, both are accepted)
    :param verbose: verbosity
    """
    MambaCache = get_mamba_cache_cls()
    # only keep type keys: register_cache_serialization returns a dict keyed
    # by class, but a name-keyed dict would otherwise crash on cls.__name__
    cls_ensemble = (
        {DynamicCache, EncoderDecoderCache}
        | {k for k in undo if isinstance(k, type)}
        | ({MambaCache} if MambaCache else set())
    )
    for cls in cls_ensemble:
        # accept both keyings: by class (current) and by class name (legacy)
        if undo.get(cls, undo.get(cls.__name__, False)):
            unregister_class_serialization(cls, verbose)
@@ -0,0 +1,340 @@
1
+ import difflib
2
+ import inspect
3
+ import pprint
4
+ import re
5
+ import textwrap
6
+ from typing import Any, Dict, Callable, List, Optional, Tuple, Union
7
+
8
+
9
def clean_code_with_black(code: str) -> str:
    """Changes the code style with :epkg:`black` if available."""
    dedented = textwrap.dedent(code)
    try:
        import black
    except ImportError:
        # black is optional, return the dedented source untouched
        return dedented
    mode = black.FileMode(line_length=98)
    try:
        return black.format_str(dedented, mode=mode)
    except black.parsing.InvalidInput as e:
        raise RuntimeError(f"Unable to parse code\n\n---\n{dedented}\n---\n") from e
20
+
21
+
22
def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
    """
    Creates a diff between two codes.

    :param code1: first code
    :param code2: second code
    :param output: if not empty, stores the output in this file
    :return: diff
    """
    diff_lines = difflib.unified_diff(
        code1.strip().splitlines(),
        code2.strip().splitlines(),
        fromfile="original",
        tofile="rewritten",
        lineterm="",
    )
    text = "\n".join(diff_lines)
    if output:
        with open(output, "w") as f:
            f.write(text)
    return text
44
+
45
+
46
class PatchInfo:
    """
    Stores information about patches.

    :param function_to_patch: function to patch, or its dotted name as a string
    :param patch: function patched (the replacement implementation)
    :param family: a category, anything to classify the patch
    """

    __slots__ = ("depends_on", "family", "function_to_patch", "patch")

    def __init__(
        self, function_to_patch: Union[str, Callable], patch: Callable, family: str = ""
    ):
        assert callable(function_to_patch) or isinstance(function_to_patch, str), (
            f"function_to_patch is not a function but {type(function_to_patch)} "
            f"- {function_to_patch!r}"
        )
        assert callable(patch), (
            f"function_to_patch is not a function but {type(patch)} - {patch!r}, "
            f"function_to_patch={function_to_patch!r}"
        )
        # a 'patched_*' original usually means a previous patch was not undone
        assert not callable(function_to_patch) or not function_to_patch.__name__.startswith(
            "patched_"
        ), (
            f"A patch was probably not removed because function_to_patch="
            f"{function_to_patch!r} and patch={patch!r}"
        )
        self.family = family
        self.function_to_patch = function_to_patch
        self.patch = patch
        self.depends_on: List[PatchInfo] = []

    def add_dependency(self, patch_info: "PatchInfo"):
        """Declares another patch this one relies on, its diff is appended."""
        self.depends_on.append(patch_info)

    def __repr__(self) -> str:
        "usual"
        return (
            (
                f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r}, "
                f"{self.family!r})"
            )
            if self.family
            else f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r})"
        )

    def to_tuple(self) -> Tuple[str, Callable, Callable]:
        "usual"
        return (self.family, self.function_to_patch, self.patch)

    def to_dict(self) -> Dict[str, Any]:
        "usual"
        return {k: getattr(self, k) for k in self.__slots__}

    def make_diff(self) -> str:
        """Returns a diff as a string."""
        if isinstance(self.function_to_patch, str):
            # no original source is available, show the whole patch source
            return clean_code_with_black(inspect.getsource(self.patch))
        src1 = clean_code_with_black(inspect.getsource(self.function_to_patch))
        src2 = clean_code_with_black(inspect.getsource(self.patch))
        diff = make_diff_code(src1, src2)
        if not self.depends_on:
            return diff
        res = [diff]
        for d in self.depends_on:
            res.append("")
            res.append(d.make_diff())
        return "\n".join(res)

    @classmethod
    def function_name(cls, f: Callable) -> str:
        """Returns the display name used in diff titles."""
        return f.__qualname__

    def format_diff(self, format: str = "raw") -> str:
        """
        Format a diff between two function as a string.

        :param format: ``'raw'`` or ``'rst'``
        :return: diff

        .. runpython::
            :showcode:
            :rst:

            import transformers
            import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
            from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
            from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
                patched_eager_mask,
            )

            eager_mask = transformers.masking_utils.eager_mask
            diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
            print(diff)
        """
        diff = self.make_diff()
        kind = self.family or ""
        if kind:
            kind = f"{kind}: "
        # fixed typo: was function_to_pach_name
        function_to_patch_name = (
            f"{self.function_to_patch!r}"
            if isinstance(self.function_to_patch, str)
            else self.function_name(self.function_to_patch)
        )
        patch_name = self.function_name(self.patch)
        title = f"{kind}{function_to_patch_name} -> {patch_name}"
        if format == "raw":
            return f"{title}\n{diff}"

        rows = [
            title,
            "=" * len(title),
            "",
            ".. code-block:: diff",
            "    :linenos:",
            "",
            textwrap.indent(diff, prefix="    "),
        ]
        return "\n".join(rows)
166
+
167
+
168
+ class PatchDetails:
169
+ """
170
+ This class is used to store patching information.
171
+ This helps understanding which rewriting was applied to which
172
+ method of functions. Page :ref:`l-patch-diff` contains all the
173
+ diff for all the implemented patches.
174
+
175
+ .. runpython::
176
+ :showcode:
177
+ :rst:
178
+
179
+ import torch
180
+ from onnx_diagnostic.torch_export_patches import torch_export_patches
181
+ from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
182
+ from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
183
+ from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
184
+
185
+ data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
186
+ model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
187
+ details = PatchDetails()
188
+ with torch_export_patches(
189
+ patch_transformers=True, patch_details=details, patch_torch=False
190
+ ):
191
+ ep = torch.export.export(
192
+ model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
193
+ )
194
+ patches = details.patches_involded_in_graph(ep.graph)
195
+ report = details.make_report(patches, format="rst")
196
+ print(report)
197
+ """
198
+
199
+ def __init__(self):
200
+ self.patched = []
201
+ self.find_cache = {}
202
+
203
+ def find(self, name: str) -> Optional[PatchInfo]:
204
+ "Finds a patch by name."
205
+ if name in self.find_cache:
206
+ return self.find_cache[name]
207
+ for p in self.patched:
208
+ if p.patch.__name__ == name:
209
+ self.find_cache[name] = p
210
+ return p
211
+ return None
212
+
213
+ def append(
214
+ self, family: str, function_to_patch: Union[str, Callable], patch: Callable
215
+ ) -> PatchInfo:
216
+ """
217
+ Stores a patch.
218
+
219
+ :param family: a category, anything to classify the patch
220
+ :param function_to_patch: function to patch
221
+ :param patch: function patched
222
+ :return: instance of PatchInfo
223
+ """
224
+ p = PatchInfo(function_to_patch, patch, family=family)
225
+ self.patched.append(p)
226
+ return p
227
+
228
+ @property
229
+ def n_patches(self) -> int:
230
+ "Returns the number of stored patches."
231
+ # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
232
+ return len(self.patched)
233
+
234
+ def data(self) -> List[Dict[str, Any]]:
235
+ """Returns the data for a dataframe."""
236
+ return [p.to_dict() for p in self.patched]
237
+
238
+ def patches_involded_in_graph(
239
+ self, graph: "torch.fx.Graph" # noqa: F821
240
+ ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]: # noqa: F821
241
+ """
242
+ Enumerates all patches impacting a graph.
243
+ The function goes through the graph node (only the main graph) and
244
+ looks into the metadata to determine if a listed patch was involved.
245
+
246
+ :param graph: fx graph
247
+ :return: list of nodes impacted by a patch
248
+ """
249
+ patches = []
250
+ for patch in self.patched:
251
+ f = patch.patch
252
+ source = inspect.getsourcefile(f)
253
+ lines, lineno = inspect.getsourcelines(f)
254
+ interval = [lineno, lineno + len(lines)]
255
+ patches.append((patch, f, source, interval))
256
+
257
+ cst = "onnx_diagnostic"
258
+ node_stack = []
259
+ for node in graph.nodes:
260
+ meta = node.meta
261
+ if "stack_trace" not in meta:
262
+ continue
263
+ stack = meta["stack_trace"]
264
+ if cst not in stack:
265
+ # to reduce the cost of the next iteration
266
+ continue
267
+ node_stack.append((node, stack))
268
+
269
+ patch_node = []
270
+ patched_nodes = set()
271
+ for patch, _f, source, interval in patches:
272
+ exp = 'File "([^"]*?%s[^"]+?)", line (\\d+)' % cst
273
+ reg = re.compile(exp)
274
+ for node, stack in node_stack:
275
+ occ = reg.findall(stack)
276
+ if not occ:
277
+ continue
278
+ for filename, line_number in occ:
279
+ if source.replace("\\", "/").strip("/") != filename.replace(
280
+ "\\", "/"
281
+ ).strip("/"):
282
+ continue
283
+ line = int(line_number)
284
+ if (
285
+ line >= interval[0]
286
+ and line <= interval[1]
287
+ and self.matching_pair(patch, node)
288
+ ):
289
+ patch_node.append((patch, node))
290
+ patched_nodes.add(id(node))
291
+
292
+ # checks all patches were discovered
293
+ for node, _ in node_stack:
294
+ assert id(node) in patched_nodes, (
295
+ f"One node was patched but no patch was found:\n"
296
+ f"node: {node.target}({','.join(map(str, node.args))}) -> {node.name}"
297
+ f"\n--\n{pprint.pformat(node.meta)}"
298
+ )
299
+
300
+ res = {} # type: ignore[var-annotated]
301
+ for patch, node in patch_node:
302
+ if patch not in res:
303
+ res[patch] = []
304
+ res[patch].append(node)
305
+ return list(res.items())
306
+
307
+ def matching_pair(cls, patch: PatchInfo, node: "torch.fx.Node") -> bool: # noqa: F821
308
+ """
309
+ Last validation for a pair. RotaryEmbedding has many rewriting
310
+ and they all end up in the same code line.
311
+ """
312
+ cls_name = patch.function_to_patch.__qualname__.split(".")[0]
313
+ if not cls_name.endswith("RotaryEmbedding"):
314
+ return True
315
+ return cls_name in str(node.meta)
316
+
317
+ def make_report(
318
+ cls,
319
+ patches: List[Tuple[PatchInfo, List["torch.fx.Node"]]], # noqa: F821
320
+ format: str = "raw",
321
+ ) -> str:
322
+ """
323
+ Creates a report based on the involved patches.
324
+
325
+ :param patches: from method :meth:`patches_involded_in_graph`
326
+ :param format: format of the report
327
+ :return: report
328
+ """
329
+ rows = []
330
+ for patch, nodes in patches:
331
+ rows.append(patch.format_diff(format=format))
332
+ rows.append("")
333
+ if format == "rst":
334
+ rows.extend(["", "", "**impacted nodes**", "", "", ".. code-block::", ""])
335
+ for node in nodes:
336
+ rows.append(
337
+ f" {node.target}({', '.join(map(str,node.args))}) -> {node.name}"
338
+ )
339
+ rows.append("")
340
+ return "\n".join(rows)