onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
import pprint
|
|
2
|
+
from typing import Any, Callable, Dict, Optional, Set
|
|
3
|
+
import packaging.version as pv
|
|
4
|
+
import optree
|
|
5
|
+
import torch
|
|
6
|
+
import transformers
|
|
7
|
+
from transformers.cache_utils import (
|
|
8
|
+
DynamicCache,
|
|
9
|
+
EncoderDecoderCache,
|
|
10
|
+
HybridCache,
|
|
11
|
+
SlidingWindowCache,
|
|
12
|
+
StaticCache,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from ..helpers import string_type
|
|
16
|
+
from .serialization import _lower_name_with_
|
|
17
|
+
|
|
18
|
+
PATCH_OF_PATCHES: Set[Any] = set()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_mamba_cache_cls() -> Optional[type]:
    """Return the ``MambaCache`` class or ``None`` when it cannot be imported.

    The class moved between :epkg:`transformers` releases: recent versions
    expose it in ``transformers.models.mamba.modeling_mamba``, older ones in
    ``transformers.cache_utils``. Both locations are tried in that order.

    :return: the ``MambaCache`` class, or ``None`` if neither import succeeds
    """
    try:
        from transformers.models.mamba.modeling_mamba import MambaCache

        return MambaCache
    except ImportError:
        try:
            # fallback for older transformers versions
            from transformers.cache_utils import MambaCache

            return MambaCache
        except ImportError:
            # MambaCache is unavailable in this environment.
            return None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def register_class_serialization(
|
|
36
|
+
cls,
|
|
37
|
+
f_flatten: Callable,
|
|
38
|
+
f_unflatten: Callable,
|
|
39
|
+
f_flatten_with_keys: Callable,
|
|
40
|
+
f_check: Optional[Callable] = None,
|
|
41
|
+
verbose: int = 0,
|
|
42
|
+
) -> bool:
|
|
43
|
+
"""
|
|
44
|
+
Registers a class.
|
|
45
|
+
It can be undone with
|
|
46
|
+
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_class_serialization`.
|
|
47
|
+
|
|
48
|
+
:param cls: class to register
|
|
49
|
+
:param f_flatten: see ``torch.utils._pytree.register_pytree_node``
|
|
50
|
+
:param f_unflatten: see ``torch.utils._pytree.register_pytree_node``
|
|
51
|
+
:param f_flatten_with_keys: see ``torch.utils._pytree.register_pytree_node``
|
|
52
|
+
:param f_check: called to check the registration was successful
|
|
53
|
+
:param verbose: verbosity
|
|
54
|
+
:return: registered or not
|
|
55
|
+
"""
|
|
56
|
+
if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
|
|
57
|
+
if verbose and cls is not None:
|
|
58
|
+
print(f"[register_class_serialization] already registered {cls.__name__}")
|
|
59
|
+
return False
|
|
60
|
+
|
|
61
|
+
if verbose:
|
|
62
|
+
print(f"[register_class_serialization] ---------- register {cls.__name__}")
|
|
63
|
+
torch.utils._pytree.register_pytree_node(
|
|
64
|
+
cls,
|
|
65
|
+
f_flatten,
|
|
66
|
+
f_unflatten,
|
|
67
|
+
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
|
|
68
|
+
flatten_with_keys_fn=f_flatten_with_keys,
|
|
69
|
+
)
|
|
70
|
+
if pv.Version(torch.__version__) < pv.Version("2.7"):
|
|
71
|
+
if verbose:
|
|
72
|
+
print(
|
|
73
|
+
f"[register_class_serialization] "
|
|
74
|
+
f"---------- register {cls.__name__} for torch=={torch.__version__}"
|
|
75
|
+
)
|
|
76
|
+
torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
|
|
77
|
+
|
|
78
|
+
# check
|
|
79
|
+
if f_check:
|
|
80
|
+
inst = f_check()
|
|
81
|
+
values, spec = torch.utils._pytree.tree_flatten(inst)
|
|
82
|
+
restored = torch.utils._pytree.tree_unflatten(values, spec)
|
|
83
|
+
assert string_type(inst, with_shape=True) == string_type(restored, with_shape=True), (
|
|
84
|
+
f"Issue with registration of class {cls} "
|
|
85
|
+
f"inst={string_type(inst, with_shape=True)}, "
|
|
86
|
+
f"restored={string_type(restored, with_shape=True)}"
|
|
87
|
+
)
|
|
88
|
+
return True
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def register_cache_serialization(
    patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
) -> Dict[str, bool]:
    """
    Registers many classes with
    :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`.
    Returns information needed to undo the registration.

    :param patch_transformers: add serialization function for
        :epkg:`transformers` package
    :param patch_diffusers: add serialization function for
        :epkg:`diffusers` package
    :param verbose: verbosity level
    :return: information to unpatch
    """
    # Classes whose existing (upstream) registration is known to be wrong or
    # incomplete, mapped to a minimal package version (None means any version).
    wrong: Dict[type, Optional[str]] = {}
    if patch_transformers:
        from .serialization.transformers_impl import WRONG_REGISTRATIONS

        wrong |= WRONG_REGISTRATIONS
    if patch_diffusers:
        from .serialization.diffusers_impl import WRONG_REGISTRATIONS

        wrong |= WRONG_REGISTRATIONS

    registration_functions = serialization_functions(
        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
    )

    # DynamicCache serialization is different in transformers and does not
    # play well with torch.export.export.
    # see test test_export_dynamic_cache_cat with NOBYPASS=1
    # :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c
    # This is caused by this line:
    # torch.fx._pytree.register_pytree_flatten_spec(
    #     DynamicCache, _flatten_dynamic_cache_for_fx)
    # so we remove it anyway
    # BaseModelOutput serialization is incomplete.
    # It does not include dynamic shapes mapping.
    for cls, version in wrong.items():
        # Re-register a class at most once per process (PATCH_OF_PATCHES keeps
        # track), and only when the installed transformers version is at least
        # the one the fix targets (version None applies to every version).
        if (
            cls in torch.utils._pytree.SUPPORTED_NODES
            and cls not in PATCH_OF_PATCHES
            # and pv.Version(torch.__version__) < pv.Version("2.7")
            and (
                version is None or pv.Version(transformers.__version__) >= pv.Version(version)
            )
        ):
            assert cls in registration_functions, (
                f"{cls} has no registration functions mapped to it, "
                f"available options are {list(registration_functions)}"
            )
            if verbose:
                print(
                    f"[_fix_registration] {cls.__name__} is unregistered and "
                    f"registered first"
                )
            # Drop the faulty upstream registration, then install ours.
            unregister_class_serialization(cls, verbose=verbose)
            registration_functions[cls](verbose=verbose)  # type: ignore[arg-type, call-arg]
            if verbose:
                print(f"[_fix_registration] {cls.__name__} done.")
            # To avoid doing it multiple times.
            PATCH_OF_PATCHES.add(cls)

    # classes with no registration at all.
    # NOTE(review): keys of ``done`` are the classes themselves (the keys of
    # ``registration_functions``), not class names, despite the Dict[str, bool]
    # annotation in the signature.
    done = {}
    for k, v in registration_functions.items():
        done[k] = v(verbose=verbose)  # type: ignore[arg-type, call-arg]
    return done
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def serialization_functions(
    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
) -> Dict[type, Callable[[int], bool]]:
    """Returns the registration functions, mapped by the class each one registers."""

    supported_classes: Set[type] = set()
    classes: Dict[type, Callable[[int], bool]] = {}
    # accumulated module namespaces (function name -> function) from the
    # serialization implementation modules
    all_functions: Dict[type, Optional[str]] = {}

    if patch_transformers:
        from .serialization.transformers_impl import (
            __dict__ as dtr,
            SUPPORTED_DATACLASSES,
            flatten_dynamic_cache,
            unflatten_dynamic_cache,
            flatten_with_keys_dynamic_cache,
            flatten_hybrid_cache,
            unflatten_hybrid_cache,
            flatten_with_keys_hybrid_cache,
            flatten_mamba_cache,
            unflatten_mamba_cache,
            flatten_with_keys_mamba_cache,
            flatten_encoder_decoder_cache,
            unflatten_encoder_decoder_cache,
            flatten_with_keys_encoder_decoder_cache,
            flatten_sliding_window_cache,
            unflatten_sliding_window_cache,
            flatten_with_keys_sliding_window_cache,
            flatten_static_cache,
            unflatten_static_cache,
            flatten_with_keys_static_cache,
        )

        all_functions.update(dtr)
        supported_classes |= SUPPORTED_DATACLASSES

        # Each value is a zero-argument (verbose is defaulted) callable that
        # performs the registration when invoked.
        transformers_classes = {
            DynamicCache: lambda verbose=verbose: register_class_serialization(
                DynamicCache,
                flatten_dynamic_cache,
                unflatten_dynamic_cache,
                flatten_with_keys_dynamic_cache,
                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
                verbose=verbose,
            ),
            HybridCache: lambda verbose=verbose: register_class_serialization(
                HybridCache,
                flatten_hybrid_cache,
                unflatten_hybrid_cache,
                flatten_with_keys_hybrid_cache,
                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
                verbose=verbose,
            ),
            EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
                EncoderDecoderCache,
                flatten_encoder_decoder_cache,
                unflatten_encoder_decoder_cache,
                flatten_with_keys_encoder_decoder_cache,
                verbose=verbose,
            ),
            SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
                SlidingWindowCache,
                flatten_sliding_window_cache,
                unflatten_sliding_window_cache,
                flatten_with_keys_sliding_window_cache,
                verbose=verbose,
            ),
            StaticCache: lambda verbose=verbose: register_class_serialization(
                StaticCache,
                flatten_static_cache,
                unflatten_static_cache,
                flatten_with_keys_static_cache,
                verbose=verbose,
            ),
        }
        # MambaCache is optional: its import location depends on the
        # transformers version and it may be absent entirely.
        MambaCache = get_mamba_cache_cls()
        if MambaCache:
            transformers_classes[MambaCache] = (
                lambda verbose=verbose: register_class_serialization(
                    MambaCache,
                    flatten_mamba_cache,
                    unflatten_mamba_cache,
                    flatten_with_keys_mamba_cache,
                    verbose=verbose,
                )
            )
        classes.update(transformers_classes)

    if patch_diffusers:
        from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu

        all_functions.update(dfu)
        supported_classes |= SUPPORTED_DATACLASSES

    # Dataclasses follow a naming convention: class ``FooBar`` is handled by
    # flatten_foo_bar / unflatten_foo_bar / flatten_with_keys_foo_bar.
    for cls in supported_classes:
        lname = _lower_name_with_(cls.__name__)
        assert (
            f"flatten_{lname}" in all_functions
        ), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}"
        # Default-argument bindings (_ln, cls, _al) freeze the loop variables
        # so every lambda keeps its own class and function names.
        classes[cls] = (
            lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization(  # noqa: E501
                cls,
                _al[f"flatten_{_ln}"],
                _al[f"unflatten_{_ln}"],
                _al[f"flatten_with_keys_{_ln}"],
                verbose=verbose,
            )
        )
    return classes
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def unregister_class_serialization(cls: type, verbose: int = 0):
    """Undo the registration performed by
    :func:`register_class_serialization` for one class.

    :param cls: class to remove from the pytree registries
    :param verbose: verbosity
    """
    # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
    # Remove the torch.fx flatten-spec entries first (populated on torch < 2.7).
    if cls in torch.fx._pytree.SUPPORTED_NODES:
        del torch.fx._pytree.SUPPORTED_NODES[cls]
    if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH:
        del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls]
    if hasattr(torch.utils._pytree, "_deregister_pytree_node"):
        # torch >= 2.7
        torch.utils._pytree._deregister_pytree_node(cls)
    else:
        if cls in torch.utils._pytree.SUPPORTED_NODES:
            del torch.utils._pytree.SUPPORTED_NODES[cls]
        # older torch mirrors registrations in optree under its own namespace
        optree.unregister_pytree_node(cls, namespace="torch")
    if cls in torch.utils._pytree.SUPPORTED_NODES:
        # NOTE(review): this local import shadows the module-level
        # ``import packaging.version as pv``; it is redundant but harmless.
        import packaging.version as pv

        # last resort for torch < 2.7 when the entry survived the steps above
        if pv.Version(torch.__version__) < pv.Version("2.7.0"):
            del torch.utils._pytree.SUPPORTED_NODES[cls]
    assert cls not in torch.utils._pytree.SUPPORTED_NODES, (
        f"{cls} was not successful unregistered "
        f"from torch.utils._pytree.SUPPORTED_NODES="
        f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}"
    )
    if verbose:
        print(f"[unregister_cache_serialization] unregistered {cls.__name__}")
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def unregister_cache_serialization(undo: Dict[Any, bool], verbose: int = 0):
    """Undo all registrations performed by :func:`register_cache_serialization`.

    :param undo: mapping returned by :func:`register_cache_serialization`,
        telling for each class whether this package registered it (and it must
        therefore be unregistered); keys may be classes or class names
    :param verbose: verbosity
    """
    MambaCache = get_mamba_cache_cls()
    # Candidate classes: the always-registered caches, the classes recorded in
    # ``undo`` (string keys are resolved through __name__ below) and MambaCache
    # when it exists.
    cls_ensemble = (
        {DynamicCache, EncoderDecoderCache}
        | {c for c in undo if isinstance(c, type)}
        | ({MambaCache} if MambaCache else set())
    )
    for cls in cls_ensemble:
        # BUG FIX: register_cache_serialization keys its result by class
        # (``done[k] = ...`` over registration_functions.items()), but the old
        # lookup here used ``undo.get(cls.__name__, False)`` which never
        # matched a class key, so nothing was ever unregistered. Accept both
        # key styles for backward compatibility.
        if undo.get(cls, undo.get(cls.__name__, False)):
            unregister_class_serialization(cls, verbose)
|
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import inspect
|
|
3
|
+
import pprint
|
|
4
|
+
import re
|
|
5
|
+
import textwrap
|
|
6
|
+
from typing import Any, Dict, Callable, List, Optional, Tuple, Union
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def clean_code_with_black(code: str) -> str:
    """Reformat *code* with :epkg:`black` when it is installed.

    The code is dedented first; when black is not installed, the dedented
    code is returned as-is. A :class:`RuntimeError` is raised when black
    cannot parse the code.
    """
    dedented = textwrap.dedent(code)
    try:
        import black
    except ImportError:
        # black is optional: fall back to the dedented source
        return dedented
    try:
        formatted = black.format_str(dedented, mode=black.FileMode(line_length=98))
    except black.parsing.InvalidInput as e:
        raise RuntimeError(f"Unable to parse code\n\n---\n{dedented}\n---\n") from e
    return formatted
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def make_diff_code(code1: str, code2: str, output: Optional[str] = None) -> str:
    """
    Creates a diff between two codes.

    :param code1: first code
    :param code2: second code
    :param output: if not empty, stores the output in this file
    :return: diff
    """
    # Compare stripped line lists; lineterm="" keeps the hunk headers clean.
    diff_lines = difflib.unified_diff(
        code1.strip().splitlines(),
        code2.strip().splitlines(),
        fromfile="original",
        tofile="rewritten",
        lineterm="",
    )
    text = "\n".join(diff_lines)
    if output:
        with open(output, "w") as f:
            f.write(text)
    return text
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class PatchInfo:
    """
    Stores information about patches.

    :param function_to_patch: function to patch, or its dotted name as a string
    :param patch: function patched
    :param family: a category, anything to classify the patch
    """

    __slots__ = ("depends_on", "family", "function_to_patch", "patch")

    def __init__(
        self, function_to_patch: Union[str, Callable], patch: Callable, family: str = ""
    ):
        assert callable(function_to_patch) or isinstance(function_to_patch, str), (
            f"function_to_patch is not a function but {type(function_to_patch)} "
            f"- {function_to_patch!r}"
        )
        # BUG FIX: this message used to blame ``function_to_patch`` while the
        # check validates ``patch``.
        assert callable(patch), (
            f"patch is not a function but {type(patch)} - {patch!r}, "
            f"function_to_patch={function_to_patch!r}"
        )
        # A patched_* name here means a previous patch was never removed.
        assert not callable(function_to_patch) or not function_to_patch.__name__.startswith(
            "patched_"
        ), (
            f"A patch was probably not removed because function_to_patch="
            f"{function_to_patch!r} and patch={patch!r}"
        )
        self.family = family
        self.function_to_patch = function_to_patch
        self.patch = patch
        self.depends_on: List[PatchInfo] = []

    def add_dependency(self, patch_info: "PatchInfo"):
        "Registers another patch this one depends on (shown in make_diff)."
        self.depends_on.append(patch_info)

    def __repr__(self) -> str:
        "usual"
        return (
            (
                f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r}, "
                f"{self.family!r})"
            )
            if self.family
            else f"{self.__class__.__name__}({self.function_to_patch!r}, {self.patch!r})"
        )

    def to_tuple(self) -> Tuple[str, Callable, Callable]:
        "usual"
        return (self.family, self.function_to_patch, self.patch)

    def to_dict(self) -> Dict[str, Any]:
        "usual"
        return {k: getattr(self, k) for k in self.__slots__}

    def make_diff(self) -> str:
        """Returns a diff as a string."""
        if isinstance(self.function_to_patch, str):
            # No original source available: show the patch source itself.
            return clean_code_with_black(inspect.getsource(self.patch))
        src1 = clean_code_with_black(inspect.getsource(self.function_to_patch))
        src2 = clean_code_with_black(inspect.getsource(self.patch))
        diff = make_diff_code(src1, src2)
        if not self.depends_on:
            return diff
        # Append the diffs of every dependent patch, separated by blank lines.
        res = [diff]
        for d in self.depends_on:
            res.append("")
            res.append(d.make_diff())
        return "\n".join(res)

    @classmethod
    def function_name(cls, f: Callable) -> str:
        "Returns the display name used in reports."
        return f.__qualname__

    def format_diff(self, format: str = "raw") -> str:
        """
        Format a diff between two function as a string.

        :param format: ``'raw'`` or ``'rst'``
        :return: diff

        .. runpython::
            :showcode:
            :rst:

            import transformers
            import onnx_diagnostic.torch_export_patches.patches.patch_transformers as ptr
            from onnx_diagnostic.torch_export_patches.patch_details import PatchInfo
            from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
                patched_eager_mask,
            )

            eager_mask = transformers.masking_utils.eager_mask
            diff = PatchInfo(eager_mask, patched_eager_mask).format_diff(format="rst")
            print(diff)
        """
        diff = self.make_diff()
        kind = self.family or ""
        if kind:
            kind = f"{kind}: "
        function_to_pach_name = (
            f"{self.function_to_patch!r}"
            if isinstance(self.function_to_patch, str)
            else self.function_name(self.function_to_patch)
        )
        patch_name = self.function_name(self.patch)
        title = f"{kind}{function_to_pach_name} -> {patch_name}"
        if format == "raw":
            return f"{title}\n{diff}"

        # rst: a title, an underline and the diff inside a code block
        rows = [
            title,
            "=" * len(title),
            "",
            ".. code-block:: diff",
            "    :linenos:",
            "",
            textwrap.indent(diff, prefix="    "),
        ]
        return "\n".join(rows)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class PatchDetails:
    """
    This class is used to store patching information.
    This helps understanding which rewriting was applied to which
    method of functions. Page :ref:`l-patch-diff` contains all the
    diff for all the implemented patches.

    .. runpython::
        :showcode:
        :rst:

        import torch
        from onnx_diagnostic.torch_export_patches import torch_export_patches
        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
        from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        details = PatchDetails()
        with torch_export_patches(
            patch_transformers=True, patch_details=details, patch_torch=False
        ):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        patches = details.patches_involded_in_graph(ep.graph)
        report = details.make_report(patches, format="rst")
        print(report)
    """

    def __init__(self):
        # stored PatchInfo instances, in insertion order
        self.patched = []
        # cache for :meth:`find`, keyed by patch function name
        self.find_cache = {}

    def find(self, name: str) -> Optional[PatchInfo]:
        "Finds a patch by name."
        if name in self.find_cache:
            return self.find_cache[name]
        for p in self.patched:
            if p.patch.__name__ == name:
                self.find_cache[name] = p
                return p
        # Misses are not cached, only successful lookups.
        return None

    def append(
        self, family: str, function_to_patch: Union[str, Callable], patch: Callable
    ) -> PatchInfo:
        """
        Stores a patch.

        :param family: a category, anything to classify the patch
        :param function_to_patch: function to patch
        :param patch: function patched
        :return: instance of PatchInfo
        """
        p = PatchInfo(function_to_patch, patch, family=family)
        self.patched.append(p)
        return p

    @property
    def n_patches(self) -> int:
        "Returns the number of stored patches."
        # Overwritten __len__ may have an impact on bool(patch_details: PatchDetails)
        return len(self.patched)

    def data(self) -> List[Dict[str, Any]]:
        """Returns the data for a dataframe."""
        return [p.to_dict() for p in self.patched]

    def patches_involded_in_graph(
        self, graph: "torch.fx.Graph"  # noqa: F821
    ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]:  # noqa: F821
        """
        Enumerates all patches impacting a graph.
        The function goes through the graph node (only the main graph) and
        looks into the metadata to determine if a listed patch was involved.

        :param graph: fx graph
        :return: list of nodes impacted by a patch
        """
        # For every stored patch, remember its source file and the line span
        # of the patch function so stack traces can be matched against it.
        patches = []
        for patch in self.patched:
            f = patch.patch
            source = inspect.getsourcefile(f)
            lines, lineno = inspect.getsourcelines(f)
            # NOTE(review): the upper bound is lineno + len(lines), one past
            # the last source line of the function — presumably harmless
            # slack; confirm if exact bounds ever matter.
            interval = [lineno, lineno + len(lines)]
            patches.append((patch, f, source, interval))

        cst = "onnx_diagnostic"
        # Keep only the nodes whose recorded stack trace mentions this package.
        node_stack = []
        for node in graph.nodes:
            meta = node.meta
            if "stack_trace" not in meta:
                continue
            stack = meta["stack_trace"]
            if cst not in stack:
                # to reduce the cost of the next iteration
                continue
            node_stack.append((node, stack))

        # Match each candidate node against each patch: a node belongs to a
        # patch when its stack trace contains a frame whose file is the patch
        # source file and whose line number falls inside the patch's span.
        patch_node = []
        patched_nodes = set()
        for patch, _f, source, interval in patches:
            exp = 'File "([^"]*?%s[^"]+?)", line (\\d+)' % cst
            reg = re.compile(exp)
            for node, stack in node_stack:
                occ = reg.findall(stack)
                if not occ:
                    continue
                for filename, line_number in occ:
                    # normalize path separators before comparing files
                    if source.replace("\\", "/").strip("/") != filename.replace(
                        "\\", "/"
                    ).strip("/"):
                        continue
                    line = int(line_number)
                    if (
                        line >= interval[0]
                        and line <= interval[1]
                        and self.matching_pair(patch, node)
                    ):
                        patch_node.append((patch, node))
                        patched_nodes.add(id(node))

        # checks all patches were discovered
        for node, _ in node_stack:
            assert id(node) in patched_nodes, (
                f"One node was patched but no patch was found:\n"
                f"node: {node.target}({','.join(map(str, node.args))}) -> {node.name}"
                f"\n--\n{pprint.pformat(node.meta)}"
            )

        # group the matched nodes by patch, preserving match order
        res = {}  # type: ignore[var-annotated]
        for patch, node in patch_node:
            if patch not in res:
                res[patch] = []
            res[patch].append(node)
        return list(res.items())

    def matching_pair(cls, patch: PatchInfo, node: "torch.fx.Node") -> bool:  # noqa: F821
        """
        Last validation for a pair. RotaryEmbedding has many rewriting
        and they all end up in the same code line.

        NOTE(review): despite the ``cls`` parameter name this is an instance
        method (no ``@classmethod`` decorator); it is called as
        ``self.matching_pair(...)``.
        """
        cls_name = patch.function_to_patch.__qualname__.split(".")[0]
        if not cls_name.endswith("RotaryEmbedding"):
            return True
        # disambiguate RotaryEmbedding patches through the node metadata
        return cls_name in str(node.meta)

    def make_report(
        cls,
        patches: List[Tuple[PatchInfo, List["torch.fx.Node"]]],  # noqa: F821
        format: str = "raw",
    ) -> str:
        """
        Creates a report based on the involved patches.

        :param patches: from method :meth:`patches_involded_in_graph`
        :param format: format of the report
        :return: report

        NOTE(review): like :meth:`matching_pair`, this is an instance method
        whose first parameter is named ``cls``.
        """
        rows = []
        for patch, nodes in patches:
            rows.append(patch.format_diff(format=format))
            rows.append("")
            if format == "rst":
                # list the impacted nodes in a literal block after the diff
                rows.extend(["", "", "**impacted nodes**", "", "", ".. code-block::", ""])
                for node in nodes:
                    rows.append(
                        f"    {node.target}({', '.join(map(str,node.args))}) -> {node.name}"
                    )
                rows.append("")
        return "\n".join(rows)
|