onnx-diagnostic 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +411 -0
- onnx_diagnostic/doc.py +4 -4
- onnx_diagnostic/export/__init__.py +1 -1
- onnx_diagnostic/export/dynamic_shapes.py +433 -22
- onnx_diagnostic/ext_test_case.py +86 -29
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/{cache_helpers.py → helpers/cache_helper.py} +41 -5
- onnx_diagnostic/{helpers.py → helpers/helper.py} +136 -659
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/onnx_helper.py +921 -0
- onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +42 -3
- onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +138 -55
- onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
- onnx_diagnostic/reference/ort_evaluator.py +7 -2
- onnx_diagnostic/torch_export_patches/__init__.py +107 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +137 -33
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +13 -2
- onnx_diagnostic/torch_export_patches/patch_inputs.py +174 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -2
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +4 -4
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +195 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3259 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +727 -0
- onnx_diagnostic/torch_models/test_helper.py +827 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +3 -4
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +3 -4
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/sbs.py +439 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/RECORD +39 -25
- onnx_diagnostic/onnx_tools.py +0 -260
- /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.3.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -1,6 +1,6 @@
 import contextlib
 import pprint
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, List, Optional, Set
 from .onnx_export_serialization import (
     flatten_with_keys_dynamic_cache,
     flatten_dynamic_cache,
@@ -12,27 +12,36 @@ from .onnx_export_serialization import (
 from .patches import patch_transformers as patch_transformers_list


-def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
+def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
     ``cls._PATCHED_CLASS_`` defines the class to patch,
     ``cls._PATCHES_`` defines the method to patch.
-    The returns information needs to be sent to :func:`
+    The returns information needs to be sent to :func:`unpatch_module_or_classes`
     to revert the changes.
+
+    :param mod: module of list of clsses to patch
+    :param verbose: verbosity
+    :return: patch info
     """
-
-
-
-
-
-
+    if isinstance(mod, list):
+        to_patch = mod
+        name = "list"
+    else:
+        to_patch = []
+        for k in dir(mod):
+            if k.startswith("patched_"):
+                v = getattr(mod, k)
+                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                    to_patch.append(v)
+        name = mod.__name__

     res = {}
     for cls in to_patch:
         original = cls._PATCHED_CLASS_
         methods = cls._PATCHES_
         if verbose:
-            print(f"[
+            print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")

         keep = {n: getattr(original, n, None) for n in methods}
         for n in methods:
@@ -42,20 +51,30 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     return res


-def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
-    """
-
-
-
-
-
-
+def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
+    """
+    Reverts modification made by :func:`patch_module_or_classes`.
+
+    :param mod: module of list of clsses to patch
+    :param verbose: verbosity
+    """
+    if isinstance(mod, list):
+        to_patch = mod
+        name = "list"
+    else:
+        to_patch = []
+        for k in dir(mod):
+            if k.startswith("patched_"):
+                v = getattr(mod, k)
+                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                    to_patch.append(v)
+        name = mod.__name__
     set_patch = set(to_patch)

     for cls, methods in info.items():
         assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
         if verbose:
-            print(f"[
+            print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}")
         original = cls._PATCHED_CLASS_
         for n, v in methods.items():
             if v is None:
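
Not part of the published diff: a minimal round-trip sketch of the renamed helpers, based only on the visible parts of the hunk above. The toy Target/patched_Target classes are invented for illustration, the import path assumes the module named in the file listing, and the behaviour of the elided loop body (assigning each patched method onto _PATCHED_CLASS_) is an assumption.

# Hypothetical example, not taken from the package tests.
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    patch_module_or_classes,
    unpatch_module_or_classes,
)


class Target:
    def greet(self) -> str:
        return "original"


class patched_Target:
    # Convention described in the docstring above: which class and which methods to swap.
    _PATCHED_CLASS_ = Target
    _PATCHES_ = ["greet"]

    def greet(self) -> str:
        return "patched"


info = patch_module_or_classes([patched_Target], verbose=1)
print(Target().greet())  # expected "patched" if the patch assigns the method
unpatch_module_or_classes([patched_Target], info, verbose=1)
print(Target().greet())  # back to "original"
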
@@ -65,9 +84,14 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
             setattr(original, n, v)


+PATCH_OF_PATCHES: Set[Any] = set()
+
+
 def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # Cache serialization: to be moved into appropriate packages
     import torch
+    import transformers
+    import packaging.version as pv

     try:
         from transformers.cache_utils import DynamicCache
@@ -100,7 +124,40 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
            flatten_with_keys_fn=flatten_with_keys_mamba_cache,
        )

-    # DynamicCache
+    # DynamicCache serialization is different in transformers and does not
+    # play way with torch.export.export.
+    # see test test_export_dynamic_cache_cat with NOBYPASS=1
+    # :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c
+    # This is caused by this line:
+    # torch.fx._pytree.register_pytree_flatten_spec(
+    #     DynamicCache, _flatten_dynamic_cache_for_fx)
+    # so we remove it anyway
+    if (
+        DynamicCache in torch.fx._pytree.SUPPORTED_NODES
+        and not PATCH_OF_PATCHES
+        # and pv.Version(torch.__version__) < pv.Version("2.7")
+        and pv.Version(transformers.__version__) >= pv.Version("4.50")
+    ):
+        if verbose:
+            print(
+                "[_register_cache_serialization] DynamicCache "
+                "is unregistered and registered first."
+            )
+        _unregister(DynamicCache)
+        torch.utils._pytree.register_pytree_node(
+            DynamicCache,
+            flatten_dynamic_cache,
+            unflatten_dynamic_cache,
+            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
+        )
+        if pv.Version(torch.__version__) < pv.Version("2.7"):
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
+            )
+        # To avoid doing it multiple times.
+        PATCH_OF_PATCHES.add(DynamicCache)
+
     unregistered_dynamic_cache = True
     if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
         if verbose > 1:
@@ -116,12 +173,13 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
            flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
        )
-        torch.
-
-
+        if pv.Version(torch.__version__) < pv.Version("2.7"):
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
+            )

     # check
-    from ..
+    from ..helpers.cache_helper import make_dynamic_cache

     cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
     values, spec = torch.utils._pytree.tree_flatten(cache)
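
Not part of the published diff: the "# check" block at the end of this hunk can be reproduced on its own. A minimal sketch, assuming transformers is installed and that register_additional_serialization_functions (whose hunk appears further down) is usable as a context manager, which the contextlib import and the try: block suggest.

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    register_additional_serialization_functions,
)

# Register the DynamicCache pytree serialization, then round-trip a cache.
with register_additional_serialization_functions(verbose=1):
    cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
    values, spec = torch.utils._pytree.tree_flatten(cache)
    restored = torch.utils._pytree.tree_unflatten(values, spec)
    print(type(restored).__name__, [v.shape for v in values])
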
@@ -180,7 +238,7 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
 def register_additional_serialization_functions(
     patch_transformers: bool = False, verbose: int = 0
 ) -> Callable:
-    """The necessary
+    """The necessary modifications to run the fx Graph."""
     fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
     done = _register_cache_serialization(verbose=verbose)
     try:
@@ -195,9 +253,10 @@ def bypass_export_some_errors(
     patch_torch: bool = True,
     patch_transformers: bool = False,
     catch_constraints: bool = True,
-    stop_if_static:
+    stop_if_static: int = 0,
     verbose: int = 0,
     patch: bool = True,
+    custom_patches: Optional[List[type["torch.nn.Module"]]] = None,  # noqa: F821
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -211,9 +270,14 @@ def bypass_export_some_errors(
         can be put to stop at that stage.
     :param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
         to stop the export as soon as an issue is detected with dynamic shapes
-        and show a stack trace indicating the exact location of the issue
+        and show a stack trace indicating the exact location of the issue,
+        ``if stop_if_static > 1``, more methods are replace to catch more
+        issues
     :param patch: if False, disable all patches except the registration of
         serialization function
+    :param custom_patches: to apply custom patches,
+        every patched class must define static attributes
+        ``_PATCHES_``, ``_PATCHED_CLASS_``
     :param verbose: to show which patches is applied

     The list of available patches.
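
Not part of the published diff: a sketch of how the extended signature might be used, assuming bypass_export_some_errors works as a context manager (consistent with the contextlib import and the restore logic in the later hunks). The toy model and the chosen options are illustrative.

import torch
from onnx_diagnostic.torch_export_patches.onnx_export_errors import bypass_export_some_errors


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


with bypass_export_some_errors(
    patch_transformers=True,  # apply the patches from patches/patch_transformers.py
    stop_if_static=2,         # > 1 also swaps ShapeEnv._check_frozen (see the hunks below)
    verbose=1,
):
    ep = torch.export.export(
        ToyModel(),
        (torch.randn(2, 3),),
        dynamic_shapes=({0: torch.export.Dim("batch")},),
    )
print(ep)
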
@@ -301,6 +365,7 @@ def bypass_export_some_errors(
     f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)

     if verbose:
+        print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}")
         print("[bypass_export_some_errors] patch sympy")

     sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
@@ -318,6 +383,8 @@ def bypass_export_some_errors(
     )

     if verbose:
+        print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}")
+        print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}")
         print("[bypass_export_some_errors] patch pytorch")

     # torch.jit.isinstance
@@ -359,23 +426,46 @@ def bypass_export_some_errors(
     )

     if stop_if_static:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+        from .patches.patch_torch import patched_ShapeEnv
+
         if verbose:
             print(
                 "[bypass_export_some_errors] assert when a dynamic dimension turns static"
             )
-
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from .patches.patch_torch import patched_ShapeEnv
+            print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")

         f_shape_env__set_replacement = ShapeEnv._set_replacement
         ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement

+        if stop_if_static > 1:
+            if verbose:
+                print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
+            f_shape_env__check_frozen = ShapeEnv._check_frozen
+            ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
+
     ####################
     # patch transformers
     ####################

     if patch_transformers:
-
+        if verbose:
+            import transformers
+
+            print(
+                f"[bypass_export_some_errors] transformers.__version__="
+                f"{transformers.__version__!r}"
+            )
+        revert_patches_info = patch_module_or_classes(
+            patch_transformers_list, verbose=verbose
+        )
+
+    if custom_patches:
+        if verbose:
+            print("[bypass_export_some_errors] applies custom patches")
+        revert_custom_patches_info = patch_module_or_classes(
+            custom_patches, verbose=verbose
+        )

     ########
     # export
@@ -397,7 +487,6 @@ def bypass_export_some_errors(
         print("[bypass_export_some_errors] remove patches")

     if patch_sympy:
-
         # tracked by https://github.com/pytorch/pytorch/issues/143494
         if f_sympy_name:
             sympy.core.numbers.IntegerConstant.name = f_sympy_name
@@ -428,6 +517,10 @@ def bypass_export_some_errors(
             print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")

         ShapeEnv._set_replacement = f_shape_env__set_replacement
+        if stop_if_static > 1:
+            if verbose:
+                print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
+            ShapeEnv._check_frozen = f_shape_env__check_frozen

     if catch_constraints:
         # to catch or skip dynamic_shapes issues
@@ -440,12 +533,23 @@ def bypass_export_some_errors(
         if verbose:
             print("[bypass_export_some_errors] restored shape constraints")

+    if custom_patches:
+        if verbose:
+            print("[bypass_export_some_errors] unpatch custom patches")
+        unpatch_module_or_classes(
+            custom_patches, revert_custom_patches_info, verbose=verbose
+        )
+
     ##############
     # transformers
     ##############

     if patch_transformers:
-
+        if verbose:
+            print("[bypass_export_some_errors] unpatch transformers")
+        unpatch_module_or_classes(
+            patch_transformers_list, revert_patches_info, verbose=verbose
+        )

     ########
     # caches
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -97,6 +97,10 @@ def flatten_dynamic_cache(
     dynamic_cache: transformers.cache_utils.DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    import transformers.cache_utils
+
+    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
+        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
     flat = [
         (k, getattr(dynamic_cache, k))
         for k in ["key_cache", "value_cache"]
@@ -111,7 +115,10 @@ def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[
 ]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     import torch
+    import transformers.cache_utils

+    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
+        return transformers.cache_utils._flatten_with_keys_dynamic_cache(d)
     values, context = flatten_dynamic_cache(d)
     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context

@@ -122,9 +129,13 @@ def unflatten_dynamic_cache(
     output_type=None,
 ) -> transformers.cache_utils.DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
-
+    import transformers.cache_utils
+
+    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
+        assert output_type is None, f"output_type={output_type} not supported"
+        return transformers.cache_utils._unflatten_dynamic_cache(values, context)

-    cache = DynamicCache()
+    cache = transformers.cache_utils.DynamicCache()
     values = dict(zip(context, values))
     for k, v in values.items():
         setattr(cache, k, v)
onnx_diagnostic/torch_export_patches/patch_inputs.py (new file)

@@ -0,0 +1,174 @@
+import inspect
+from typing import Any, Dict, Optional, Tuple
+import torch
+import transformers
+from ..helpers import string_type
+from ..helpers.cache_helper import make_dynamic_cache
+
+
+def _process_cache(k: str, v):
+    assert k != "position_ids" or isinstance(
+        k, torch.Tensor
+    ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
+    if (
+        isinstance(v, list)
+        and all(isinstance(i, tuple) for i in v)
+        and set(len(t) for t in v) == {2}
+    ):
+        # A dynamicCache
+        cache = make_dynamic_cache(v)
+        return cache
+    if isinstance(v, torch.Tensor):
+        return v
+    raise NotImplementedError(
+        f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}"
+    )
+
+
+def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
+    if cls is transformers.cache_utils.DynamicCache:
+        assert subset, "DynamicCache cannot be empty"
+        values = set(map(str, subset.values()))
+        assert len(values) == 1, (
+            f"Inconsistencies in subset={subset}, found={values}, "
+            f"it cannot be a {cls}, value={string_type(value)}"
+        )
+        cache_length = len(value.key_cache)
+        for v in subset.values():
+            axes = v
+            break
+        new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]]
+        return new_shape
+    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+        raise NotImplementedError(
+            f"_make_shape not implemented for registered class={cls}, "
+            f"subset={subset}, value={string_type(value)}"
+        )
+    raise NotImplementedError(
+        f"_make_shape not implemented for cls={cls}, "
+        f"subset={subset}, value={string_type(value)}"
+    )
+
+
+def convert_dynamic_axes_into_dynamic_shapes(
+    model: torch.nn.Module,
+    args: Optional[Tuple[Any, ...]] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
+    prefix_mapping: Optional[Dict[str, str]] = None,
+    verbose: int = 0,
+) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
+    """
+    Converts the input from an export to something :func:`torch.export.export` can handle.
+
+    :param model: model to convert (used to extract the signature)
+    :param args: positional arguments
+    :param kwargs: named arguments
+    :param dynamic_axes: dynamic axes
+    :param prefix_mapping: prefix mapping
+    :param verbose: verbosity
+    :return: (args, kwargs, dynamic shapes)
+    """
+    new_kwargs = {}
+    if args:
+        assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
+        plus = 0 if isinstance(model, torch.nn.Module) else 1
+        print(
+            f"[convert_dynamic_axes_into_dynamic_shapes] "
+            f"mapping args to kwargs for model="
+            f"{model if plus else model.__class__.__name__}"
+        )
+        pars = inspect.signature(model.forward).parameters
+        assert len(pars) >= len(
+            args
+        ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"
+
+        for i, p in enumerate(pars):
+            if i < plus:
+                continue
+            if i - plus >= len(args):
+                break
+            if verbose:
+                print(
+                    f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] "
+                    f"to {p!r} ({string_type(args[i-plus])})"
+                )
+            new_kwargs[p] = args[i - plus]
+
+    if kwargs:
+        for k, v in kwargs.items():
+            assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
+            new_kwargs[k] = v
+
+    # process
+    updated_kwargs = {}
+    changes = {}
+    for k, v in new_kwargs.items():
+        if isinstance(v, torch.Tensor):
+            updated_kwargs[k] = v
+            continue
+        if isinstance(v, list):
+            # cache?
+            updated_kwargs[k] = _process_cache(k, v)
+            if type(updated_kwargs[k]) is not type(v):
+                # A cache was introduced.
+                if verbose:
+                    print(
+                        f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
+                        f"{k!r} was changed into {type(updated_kwargs[k])}"
+                    )
+                changes[k] = type(updated_kwargs[k])
+            continue
+        raise NotImplementedError(
+            f"Unexpected type {type(v)} for parameter {k!r} "
+            f"({string_type(v, with_shape=True)})"
+        )
+
+    # process dynamic axes
+    if changes:
+        dynamic_shapes = {}
+        done = set()
+        for k, v in dynamic_axes.items():
+            if k not in changes and k in updated_kwargs and isinstance(v, dict):
+                dynamic_shapes[k] = v
+                continue
+            if "." in k:
+                # something like present.0.key
+                prefix = k.split(".")[0]
+                if prefix in done:
+                    continue
+                args_prefix = (
+                    prefix_mapping[prefix]
+                    if prefix_mapping and prefix in prefix_mapping
+                    else prefix
+                )
+                if args_prefix in updated_kwargs and args_prefix in changes:
+                    # A cache.
+                    cls = changes[args_prefix]
+                    dynamic_shapes[args_prefix] = _make_shape(
+                        {
+                            _: __
+                            for _, __ in dynamic_axes.items()
+                            if _.startswith(f"{prefix}.")
+                        },
+                        cls,
+                        updated_kwargs[args_prefix],
+                    )
+                    done.add(prefix)
+                    continue
+            if k not in updated_kwargs:
+                # dynamic axes not in the given inputs, should be raise an exception?
+                if verbose:
+                    print(
+                        f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
+                        f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
+                    )
+                continue
+            raise NotImplementedError(
+                f"Unable to process dynamic axes {k!r}, axes={v}, "
+                f"value={string_type(updated_kwargs[k], with_shape=True)}, "
+                f"dynamic axes={dynamic_axes}, "
+                f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
+            )
+
+    return (), updated_kwargs, dynamic_shapes
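
Not part of the published diff: a sketch of the new helper on a toy model, following the cache branch visible above. The model, tensor shapes, and axis names are invented; the list of (key, value) pairs is what _process_cache converts into a DynamicCache, and the exact structure of the returned dynamic shapes follows from the code as reconstructed here.

import torch
from onnx_diagnostic.torch_export_patches.patch_inputs import (
    convert_dynamic_axes_into_dynamic_shapes,
)


class ToyModel(torch.nn.Module):
    def forward(self, input_ids, past_key_values):
        return input_ids


args = (
    torch.ones((2, 8), dtype=torch.int64),
    [(torch.rand(2, 4, 16, 8), torch.rand(2, 4, 16, 8))],  # one (key, value) layer
)
dynamic_axes = {
    "input_ids": {0: "batch", 1: "seq"},
    "past_key_values.0.key": {0: "batch", 2: "cache_length"},
    "past_key_values.0.value": {0: "batch", 2: "cache_length"},
}
new_args, new_kwargs, dyn_shapes = convert_dynamic_axes_into_dynamic_shapes(
    ToyModel(), args=args, dynamic_axes=dynamic_axes, verbose=1
)
# new_args is empty, new_kwargs["past_key_values"] is now a DynamicCache and
# dyn_shapes["past_key_values"] repeats the axes for each layer of the cache.
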
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -131,7 +131,7 @@ def patched__broadcast_shapes(*_shapes):
         assert isinstance(shape, Sequence)

     # Computes common shape
-    common_shape
+    common_shape = [  # List[Union[int, torch.SymInt]]
         1,
     ] * reduce(max, (len(shape) for shape in shapes))
     for _arg_idx, shape in enumerate(shapes):
@@ -150,6 +150,16 @@ def patched__broadcast_shapes(*_shapes):

 class patched_ShapeEnv:

+    def _check_frozen(
+        self, expr: "sympy.Basic", concrete_val: "sympy.Basic"  # noqa: F821
+    ) -> None:
+        if self.frozen:
+            self.counter["ignored_backward_guard"] += 1
+            raise AssertionError(
+                f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
+                f"this could result in accuracy problems."
+            )
+
     def _set_replacement(
         self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str  # noqa: F821
     ) -> None:
@@ -314,7 +324,7 @@ class patched_ShapeEnv:
         # )
         # self.log.debug("SPECIALIZATION", stack_info=True)
         assert msg != "range_refined_to_singleton", (
-            f"A dynamic dimension becomes static! "
+            f"patched_ShapeEnv: A dynamic dimension becomes static! "
            f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
         )
         # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -4,9 +4,9 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 import torch
 import transformers
-
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.cache_utils import StaticCache, Cache, DynamicCache
-from ...torch_test_helper import is_torchdynamo_exporting
+from ...helpers.torch_test_helper import is_torchdynamo_exporting


 def _patch_make_causal_mask(
@@ -54,7 +54,7 @@ if sys.version_info[:2] <= (3, 11):
     """

     _PATCHES_ = ["_make_causal_mask"]
-    _PATCHED_CLASS_ =
+    _PATCHED_CLASS_ = AttentionMaskConverter

     @staticmethod
     def _make_causal_mask(
@@ -79,7 +79,7 @@ else:
     """

     _PATCHES_ = ["_make_causal_mask"]
-    _PATCHED_CLASS_ =
+    _PATCHED_CLASS_ = AttentionMaskConverter

     @staticmethod
     def _make_causal_mask(
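
Not part of the published diff: the _PATCHES_ / _PATCHED_CLASS_ convention shown above is the same one the new custom_patches argument of bypass_export_some_errors expects. A sketch with a made-up patch class; the replacement body is only a placeholder.

from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from onnx_diagnostic.torch_export_patches.onnx_export_errors import bypass_export_some_errors


class patched_MyAttentionMaskConverter:
    # Same two attributes as the classes patched in this file.
    _PATCHES_ = ["_make_causal_mask"]
    _PATCHED_CLASS_ = AttentionMaskConverter

    @staticmethod
    def _make_causal_mask(*args, **kwargs):
        # A real patch would return the rewritten mask here.
        raise NotImplementedError("illustrative placeholder")


with bypass_export_some_errors(custom_patches=[patched_MyAttentionMaskConverter], verbose=1):
    ...  # torch.export.export(...) would run with the custom patch applied
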
onnx_diagnostic/torch_models/hghub/__init__.py (new file)

@@ -0,0 +1 @@
+from .model_inputs import get_untrained_model_with_inputs