onnx-diagnostic 0.8.0 (py3-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
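
The remainder of this diff reproduces a single source file in full; judging by the 1098-line addition in the listing above, it appears to correspond to onnx_diagnostic/torch_export_patches/onnx_export_errors.py, which implements the torch_export_patches context manager that temporarily patches sympy, torch, and transformers so that torch.export.export succeeds on models those patches target. As a quick orientation, here is a minimal usage sketch based on the docstring contained in that file; the model and inputs are placeholders, and the top-level import assumes the subpackage's __init__ re-exports torch_export_patches (torch is required, plus transformers when patch_transformers=True):

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed re-export

model = torch.nn.Linear(4, 4)      # placeholder model
inputs = (torch.randn(2, 4),)      # placeholder example inputs

# The context manager applies the patches, yields a callable that rewrites
# inputs (e.g. caches) into exportable structures, and restores everything on exit.
with torch_export_patches(patch_transformers=True, verbose=1) as modificator:
    inputs = modificator(inputs)
    ep = torch.export.export(model, inputs)

The full diff of the file follows.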
@@ -0,0 +1,1098 @@
+import functools
+import importlib
+import inspect
+import contextlib
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from .onnx_export_serialization import (
+    register_cache_serialization,
+    unregister_cache_serialization,
+)
+from .patches import patch_transformers as patch_transformers_list
+from .patch_details import PatchDetails
+
+
+def get_function(name: str) -> Tuple[type, Callable]:
+    """Returns the module and the function based on its name."""
+    spl = name.split(".")
+    module_name = ".".join(spl[:-1])
+    fname = spl[-1]
+    mod = importlib.import_module(module_name)
+    if not hasattr(mod, fname):
+        return None, None
+    return mod, getattr(mod, fname)
+
+
+@functools.lru_cache
+def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+    """Returns the list of patches to make for a specific module."""
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+            else:
+                # a function
+                doc = v.__doc__.lstrip()
+                if doc.startswith("manual patch"):
+                    continue
+                reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
+                fall = reg.findall(doc)
+                assert (
+                    len(fall) == 1
+                ), f"Unable to find patching information for {v} in \n{doc}"
+                fmod, f = get_function(fall[0])
+                if fmod is None and f is None:
+                    # The function does not exist in this version of transformers.
+                    # No patch is needed.
+                    continue
+                to_patch.append({"module": fmod, "function": f, "patch": v})
+
+    name = mod.__name__
+    return name, to_patch
+
+
+def patch_module_or_classes(
+    mod, verbose: int = 0, patch_details: Optional[PatchDetails] = None
+) -> Dict[type, Dict[type, Callable]]:
+    """
+    Applies all patches defined in classes prefixed by ``patched_``
+    ``cls._PATCHED_CLASS_`` defines the class to patch,
+    ``cls._PATCHES_`` defines the method to patch.
+    The returns information needs to be sent to :func:`unpatch_module_or_classes`
+    to revert the changes.
+
+    :param mod: module of list of clsses to patch
+    :param verbose: verbosity
+    :param patch_details: used to store information about the applied patches
+    :return: patch info
+    """
+    if isinstance(mod, list):
+        to_patch = mod
+        name = "list"
+        list_name = "auto/list"
+    else:
+        name, to_patch = get_patches(mod, verbose)
+        list_name = f"auto/{mod.__name__.split('.')[-1]}"
+
+    res = {}
+    for cls in to_patch:
+        if isinstance(cls, dict):
+            # a function
+            keep = {}
+            original = cls["module"]
+            f = cls["function"]
+            assert not f.__name__.startswith("patched_"), (
+                f"The function {f} was already patched or the patch was not removed, "
+                f"original={original}"
+            )
+            res[f] = f
+            if verbose:
+                print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+            if patch_details:
+                patch_details.append(list_name, getattr(original, f.__name__), cls["patch"])
+            setattr(original, f.__name__, cls["patch"])
+            continue
+
+        original = cls._PATCHED_CLASS_
+        methods = [_ for _ in cls._PATCHES_ if _ is not None]
+        if verbose:
+            print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
+
+        keep = {n: getattr(original, n, None) for n in methods}
+        for n in methods:
+            if patch_details:
+                if hasattr(original, n):
+                    p = patch_details.append(list_name, getattr(original, n), getattr(cls, n))
+                else:
+                    p = patch_details.append(
+                        list_name, f"{original.__name__}{n}", getattr(cls, n)
+                    )
+                if "@patched_dynamic_rope_update" in inspect.getsource(getattr(cls, n)):
+                    # a tweak to include that patch.
+                    f = patch_details.find("patched_dynamic_rope_update")
+                    if f is not None:
+                        p.add_dependency(f)
+            setattr(original, n, getattr(cls, n))
+        res[cls] = keep
+
+    return res
+
+
+def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
+    """
+    Reverts modification made by :func:`patch_module_or_classes`.
+
+    :param mod: module of list of clsses to patch
+    :param verbose: verbosity
+    """
+    if isinstance(mod, list):
+        to_patch = mod
+        name = "list"
+    else:
+        name, to_patch = get_patches(mod, verbose)
+
+    set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+    dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
+
+    for cls, methods in info.items():
+        if cls in set_patch_cls:
+            if verbose:
+                print(
+                    f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                )
+            original = cls._PATCHED_CLASS_
+            for n, v in methods.items():
+                if v is None:
+                    # The method did not exist. We remove it.
+                    delattr(original, n)
+                else:
+                    setattr(original, n, v)
+            continue
+        assert cls in dict_patch_fct, (
+            f"No patch registered for {cls} in {mod} "
+            f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+        )
+        patch = dict_patch_fct[cls]
+        if verbose:
+            print(
+                f"[unpatch_module_or_classes] function "
+                f"{patch['module'].__name__}.{cls.__name__}"
+            )
+        setattr(patch["module"], cls.__name__, patch["function"])
+
+
+@contextlib.contextmanager
+def register_additional_serialization_functions(
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
+) -> Callable:
+    """The necessary modifications to run the fx Graph."""
+    fct_callable = (
+        replacement_before_exporting
+        if patch_transformers or patch_diffusers
+        else (lambda x: x)
+    )
+    done = register_cache_serialization(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
+    try:
+        yield fct_callable
+    finally:
+        unregister_cache_serialization(done, verbose=verbose)
+
+
+def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Optional[Callable], ...]:
+    import sympy
+
+    f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
+
+    if verbose:
+        print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
+        print("[torch_export_patches] patch sympy")
+
+    sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
+    if patch_details:
+        patch_details.append(
+            "sympy",
+            f_sympy_name or "sympy.core.numbers.IntegerConstant.name",
+            sympy.core.numbers.IntegerConstant.name,
+        )
+    return (f_sympy_name,)
+
+
+def _unpatch_sympy(verbose: int, f_sympy_name: Optional[Callable]):
+    # tracked by https://github.com/pytorch/pytorch/issues/143494
+    import sympy
+
+    if f_sympy_name:
+        sympy.core.numbers.IntegerConstant.name = f_sympy_name
+    else:
+        delattr(sympy.core.numbers.IntegerConstant, "name")
+
+    if verbose:
+        print("[torch_export_patches] restored sympy functions")
+
+
+def _patch_torch(
+    verbose: int,
+    patch_details: PatchDetails,
+    patch_torch: int,
+    catch_constraints: bool,
+    stop_if_static: int,
+) -> Tuple[Optional[Callable], ...]:
+    import torch
+    import torch.jit
+    import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+    from .patches.patch_torch import (
+        patched_infer_size,
+        patched_vmap,
+        patched__broadcast_shapes,
+        patched__constrain_user_specified_dimhint_range,
+        _catch_produce_guards_and_solve_constraints,
+        patch__check_input_constraints_for_graph,
+        patched__broadcast_in_dim_meta,
+        patched__broadcast_in_dim_meta_level_2,
+        patched__maybe_broadcast,
+        patched_ShapeEnv,
+    )
+
+    f___constrain_user_specified_dimhint_range = None
+    f__broadcast_in_dim_meta = None
+    f__broadcast_shapes = None
+    f__check_input_constraints_for_graph = None
+    f__maybe_broadcast = None
+    f_broadcast_in_dim = None
+    f_infer_size = None
+    f_jit_isinstance = None
+    f_mark_static_address = None
+    f_produce_guards_and_solve_constraints = None
+    f_shape_env__check_frozen = None
+    f_shape_env__evaluate_expr = None
+    f_shape_env__log_guard = None
+    f_shape_env__set_replacement = None
+    f_vmap = None
+
+    if verbose:
+        print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
+        print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
+        print("[torch_export_patches] patch pytorch")
+
+    # torch.vmap
+    f_vmap = torch.vmap
+    torch.vmap = patched_vmap
+
+    # torch.jit.isinstance
+    f_jit_isinstance = torch.jit.isinstance
+    torch.jit.isinstance = isinstance
+
+    # torch._dynamo.mark_static_address
+    f_mark_static_address = torch._dynamo.mark_static_address
+    torch._dynamo.mark_static_address = lambda *_, **y_: None
+
+    # torch._subclasses.fake_impls.infer_size
+    f_infer_size = torch._subclasses.fake_impls.infer_size
+    torch._subclasses.fake_impls.infer_size = patched_infer_size
+    if patch_details:
+        patch_details.append("torch", f_infer_size, patched_infer_size)
+
+    # torch._refs._broadcast_shapes
+    f__broadcast_shapes = torch._refs._broadcast_shapes
+    torch._refs._broadcast_shapes = patched__broadcast_shapes
+    torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
+    if patch_details:
+        patch_details.append("torch", f__broadcast_shapes, patched__broadcast_shapes)
+
+    # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+    f___constrain_user_specified_dimhint_range = (
+        torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+    )
+    torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+        patched__constrain_user_specified_dimhint_range
+    )
+    if patch_details:
+        patch_details.append(
+            "torch",
+            f___constrain_user_specified_dimhint_range,
+            patched__constrain_user_specified_dimhint_range,
+        )
+
+    # torch._prims._broadcast_in_dim_meta
+    f_broadcast_in_dim = torch._prims.broadcast_in_dim
+    f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+    _patched_dim_f = (
+        patched__broadcast_in_dim_meta_level_2
+        if patch_torch == 2
+        else patched__broadcast_in_dim_meta
+    )
+    torch._prims._broadcast_in_dim_meta = _patched_dim_f
+    torch._prims.broadcast_in_dim = _patched_dim_f
+    if patch_details:
+        patch_details.append("torch", f__broadcast_in_dim_meta, _patched_dim_f)
+
+    # torch._refs._maybe_broadcast
+    f__maybe_broadcast = torch._refs._maybe_broadcast
+    torch._refs._maybe_broadcast = patched__maybe_broadcast
+    if patch_details:
+        patch_details.append("torch", f__maybe_broadcast, patched__maybe_broadcast)
+
+    # ShapeEnv
+    f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+    ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+    if patch_details:
+        patch_details.append(
+            "torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr
+        )
+
+    # torch._export.non_strict_utils.produce_guards_and_solve_constraints
+    if catch_constraints:
+        if verbose:
+            print("[torch_export_patches] modifies shape constraints")
+        f_produce_guards_and_solve_constraints = (
+            torch._export.non_strict_utils.produce_guards_and_solve_constraints
+        )
+        f__check_input_constraints_for_graph = (
+            torch._export.utils._check_input_constraints_for_graph
+        )
+        torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+            lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints(
+                f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs
+            )
+        )
+        torch._export.utils._check_input_constraints_for_graph = (
+            lambda *args, **kwargs: patch__check_input_constraints_for_graph(
+                f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs
+            )
+        )
+
+    if patch_torch and stop_if_static:
+        ShapeEnv._log_guard_remember = ShapeEnv._log_guard
+
+        if verbose:
+            print("[torch_export_patches] assert when a dynamic dimension turns static")
+            print("[torch_export_patches] replaces ShapeEnv._set_replacement")
+
+        f_shape_env__set_replacement = ShapeEnv._set_replacement
+        ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
+        if patch_details:
+            patch_details.append(
+                "torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement
+            )
+
+        if verbose:
+            print("[torch_export_patches] replaces ShapeEnv._log_guard")
+        f_shape_env__log_guard = ShapeEnv._log_guard
+        ShapeEnv._log_guard = patched_ShapeEnv._log_guard
+        if patch_details:
+            patch_details.append("torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard)
+
+        if stop_if_static > 1:
+            if verbose:
+                print("[torch_export_patches] replaces ShapeEnv._check_frozen")
+            f_shape_env__check_frozen = ShapeEnv._check_frozen
+            ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
+            if patch_details:
+                patch_details.append(
+                    "torch", f_shape_env__check_frozen, ShapeEnv._check_frozen
+                )
+    return (
+        f___constrain_user_specified_dimhint_range,
+        f__broadcast_in_dim_meta,
+        f__broadcast_shapes,
+        f__check_input_constraints_for_graph,
+        f__maybe_broadcast,
+        f_broadcast_in_dim,
+        f_infer_size,
+        f_jit_isinstance,
+        f_mark_static_address,
+        f_produce_guards_and_solve_constraints,
+        f_shape_env__check_frozen,
+        f_shape_env__evaluate_expr,
+        f_shape_env__log_guard,
+        f_shape_env__set_replacement,
+        f_vmap,
+    )
+
+
+def _unpatch_torch(
+    verbose: int,
+    _patch_details: PatchDetails,
+    patch_torch: int,
+    catch_constraints: bool,
+    stop_if_static: int,
+    f___constrain_user_specified_dimhint_range: Optional[Callable],
+    f__broadcast_in_dim_meta: Optional[Callable],
+    f__broadcast_shapes: Optional[Callable],
+    f__check_input_constraints_for_graph: Optional[Callable],
+    f__maybe_broadcast: Optional[Callable],
+    f_broadcast_in_dim: Optional[Callable],
+    f_infer_size: Optional[Callable],
+    f_jit_isinstance: Optional[Callable],
+    f_mark_static_address: Optional[Callable],
+    f_produce_guards_and_solve_constraints: Optional[Callable],
+    f_shape_env__check_frozen: Optional[Callable],
+    f_shape_env__evaluate_expr: Optional[Callable],
+    f_shape_env__log_guard: Optional[Callable],
+    f_shape_env__set_replacement: Optional[Callable],
+    f_vmap: Optional[Callable],
+):
+    import torch
+    import torch.jit
+    import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    # this should disappear when torch.jit is removed
+    torch.vmap = f_vmap
+    torch.jit.isinstance = f_jit_isinstance
+    torch._dynamo.mark_static_address = f_mark_static_address
+    # tracked by https://github.com/pytorch/pytorch/issues/143495
+    torch._subclasses.fake_impls.infer_size = f_infer_size
+    torch._refs._broadcast_shapes = f__broadcast_shapes
+    torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
+    torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+        f___constrain_user_specified_dimhint_range
+    )
+    torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+    torch._prims.broadcast_in_dim = f_broadcast_in_dim
+    torch._refs._maybe_broadcast = f__maybe_broadcast
+    ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
+
+    if verbose:
+        print("[torch_export_patches] restored pytorch functions")
+
+    if patch_torch and stop_if_static:
+        if verbose:
+            print("[torch_export_patches] restored ShapeEnv._set_replacement")
+
+        ShapeEnv._set_replacement = f_shape_env__set_replacement
+
+        if verbose:
+            print("[torch_export_patches] restored ShapeEnv._log_guard")
+
+        ShapeEnv._log_guard = f_shape_env__log_guard
+
+        if stop_if_static > 1:
+            if verbose:
+                print("[torch_export_patches] restored ShapeEnv._check_frozen")
+            ShapeEnv._check_frozen = f_shape_env__check_frozen
+
+    if patch_torch and catch_constraints:
+        # to catch or skip dynamic_shapes issues
+        torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+            f_produce_guards_and_solve_constraints
+        )
+        torch._export.utils._check_input_constraints_for_graph = (
+            f__check_input_constraints_for_graph
+        )
+        if verbose:
+            print("[torch_export_patches] restored shape constraints")
+
+
+def _patch_transformers(
+    verbose: int, patch_details: PatchDetails
+) -> Tuple[Optional[Callable], ...]:
+    import transformers
+
+    try:
+        import transformers.masking_utils as masking_utils
+    except ImportError:
+        masking_utils = None
+
+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
+    try:
+        import transformers.modeling_rope_utils as modeling_rope_utils
+    except ImportError:
+        modeling_rope_utils = None
+
+    if (
+        patch_details
+        and modeling_rope_utils
+        and hasattr(modeling_rope_utils, "dynamic_rope_update")
+    ):
+        patch_details.append(
+            "patch_transformers",
+            modeling_rope_utils.dynamic_rope_update,
+            patch_transformers_list.patched_dynamic_rope_update,
+        )
+
+    if verbose:
+        print(f"[torch_export_patches] transformers.__version__={transformers.__version__!r}")
+    assert not sdpa_attention.sdpa_attention_forward.__name__.startswith("patched_"), (
+        f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+        f"sdpa_attention.sdpa_attention_forward={sdpa_attention.sdpa_attention_forward}"
+    )
+
+    f_transformers__vmap_for_bhqkv = None
+    f_transformers_eager_mask = None
+    f_transformers_sdpa_attention_forward = None
+    f_transformers_sdpa_mask = None
+    f_transformers_sdpa_mask_recent_torch = None
+
+    if (  # vmap
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "_vmap_for_bhqkv")
+    ):
+        if verbose:
+            print("[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv")
+        f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
+        masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers__vmap_for_bhqkv,
+                patch_transformers_list.patched__vmap_for_bhqkv,
+            )
+
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+        f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+        masking_utils.sdpa_mask_recent_torch = (
+            patch_transformers_list.patched_sdpa_mask_recent_torch
+        )
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_sdpa_mask_recent_torch,
+                patch_transformers_list.patched_sdpa_mask_recent_torch,
+            )
+        if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+            if verbose:
+                print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask")
+            f_transformers_sdpa_mask = masking_utils.sdpa_mask
+            masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch
+            if patch_details:
+                patch_details.append(
+                    "transformers",
+                    f_transformers_sdpa_mask,
+                    patch_transformers_list.patched_sdpa_mask_recent_torch,
+                )
+        else:
+            f_transformers_sdpa_mask = None
+
+    if (  # eager_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "eager_mask")
+    ):
+        if verbose:
+            print("[torch_export_patches] patches transformers.masking_utils.eager_mask")
+        f_transformers_eager_mask = masking_utils.eager_mask
+        masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_eager_mask,
+                patch_transformers_list.patched_eager_mask,
+            )
+        if (
+            "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+            == f_transformers_eager_mask
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                patch_transformers_list.patched_eager_mask
+            )
+
+    if (  # sdpa_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+        and f_transformers_sdpa_mask is not None
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] == f_transformers_sdpa_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+
+    if (  # sdpa_attention_forward
+        sdpa_attention is not None
+        and modeling_utils is not None
+        and hasattr(sdpa_attention, "sdpa_attention_forward")
+        and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+        and hasattr(modeling_utils, "AttentionInterface")
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+            )
+        f_transformers_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        assert not f_transformers_sdpa_attention_forward.__name__.startswith("patched_"), (
+            f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+            f"sdpa_attention.sdpa_attention_forward={f_transformers_sdpa_attention_forward}"
+        )
+        sdpa_attention.sdpa_attention_forward = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        modeling_utils.sdpa_attention_forward = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_sdpa_attention_forward,
+                patch_transformers_list.patched_sdpa_attention_forward,
+            )
+
+    revert_patches_info = patch_module_or_classes(
+        patch_transformers_list, verbose=verbose, patch_details=patch_details
+    )
+
+    return (
+        f_transformers__vmap_for_bhqkv,
+        f_transformers_eager_mask,
+        f_transformers_sdpa_attention_forward,
+        f_transformers_sdpa_mask,
+        f_transformers_sdpa_mask_recent_torch,
+        revert_patches_info,
+    )
+
+
+def _unpatch_transformers(
+    verbose: int,
+    _patch_details: PatchDetails,
+    f_transformers__vmap_for_bhqkv: Optional[Callable],
+    f_transformers_eager_mask: Optional[Callable],
+    f_transformers_sdpa_attention_forward: Optional[Callable],
+    f_transformers_sdpa_mask: Optional[Callable],
+    f_transformers_sdpa_mask_recent_torch: Optional[Callable],
+    revert_patches_info: Optional[Callable],
+):
+
+    try:
+        import transformers.masking_utils as masking_utils
+    except ImportError:
+        masking_utils = None
+
+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
+    try:
+        import transformers.masking_utils as masking_utils
+    except ImportError:
+        masking_utils = None
+    if verbose:
+        print("[torch_export_patches] unpatches transformers")
+
+    if (  # vmap
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "_vmap_for_bhqkv")
+    ):
+        assert f_transformers__vmap_for_bhqkv.__name__ == "_vmap_for_bhqkv", (
+            f"corrupted function '_vmap_for_bhqkv', its name is "
+            f"{f_transformers__vmap_for_bhqkv.__name__!r}"
+        )
+        masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
+        if verbose:
+            print("[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv")
+
+        assert f_transformers_sdpa_mask_recent_torch.__name__ == "sdpa_mask_recent_torch", (
+            f"corrupted function 'sdpa_mask_recent_torch', its name is "
+            f"{f_transformers_sdpa_mask_recent_torch.__name__!r}"
+        )
+        masking_utils.sdpa_mask_recent_torch = f_transformers_sdpa_mask_recent_torch
+
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+
+        if f_transformers_sdpa_mask is not None:
+            assert f_transformers_sdpa_mask.__name__ in (
+                "sdpa_mask",
+                "sdpa_mask_recent_torch",
+            ), (
+                f"corrupted function 'sdpa_mask', its name is "
+                f"{f_transformers_sdpa_mask.__name__!r}"
+            )
+            masking_utils.sdpa_mask = f_transformers_sdpa_mask
+            if verbose:
+                print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")
+
+    if (  # eager_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "eager_mask")
+    ):
+        assert f_transformers_eager_mask.__name__ == "eager_mask", (
+            f"corrupted function 'eager_mask', its name is "
+            f"{f_transformers_eager_mask.__name__!r}"
+        )
+        masking_utils.eager_mask = f_transformers_eager_mask
+        if verbose:
+            print("[torch_export_patches] restored transformers.masking_utils.eager_mask")
+        assert masking_utils.eager_mask.__name__ == "eager_mask", (
+            f"corrupted function 'eager_mask', its name is "
+            f"{masking_utils.eager_mask.__name__!r}"
+        )
+        if (
+            "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+            == patch_transformers_list.patched_eager_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = f_transformers_eager_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+            assert masking_utils.eager_mask.__name__ == "eager_mask", (
+                f"corrupted function 'eager_mask', its name is "
+                f"{masking_utils.eager_mask.__name__!r}"
+            )
+
+    if (  # sdpa_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+    ):
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == patch_transformers_list.patched_sdpa_mask_recent_torch
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = f_transformers_sdpa_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+
+    if (  # sdpa_attention_forward
+        sdpa_attention is not None
+        and modeling_utils is not None
+        and hasattr(sdpa_attention, "sdpa_attention_forward")
+        and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+        and hasattr(modeling_utils, "AttentionInterface")
+    ):
+        sdpa_attention.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+        modeling_utils.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            f_transformers_sdpa_attention_forward
+        )
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.integrations.sdpa_attention."
+                "sdpa_attention_forward"
+            )
+
+    unpatch_module_or_classes(patch_transformers_list, revert_patches_info, verbose=verbose)
+
+
+@contextlib.contextmanager
+def torch_export_patches(
+    patch_sympy: bool = True,
+    patch_torch: Union[bool, int] = True,
+    patch_transformers: bool = False,
+    patch_diffusers: bool = False,
+    catch_constraints: bool = True,
+    stop_if_static: int = 0,
+    verbose: int = 0,
+    patch: bool = True,
+    custom_patches: Optional[List[type["torch.nn.Module"]]] = None,  # noqa: F821
+    rewrite: Optional[List[Callable]] = None,
+    dump_rewriting: Optional[str] = None,
+    patch_details: Optional[PatchDetails] = None,
+) -> Callable:
+    """
+    Tries to bypass some situations :func:`torch.export.export` does not support.
+    See also :ref:`l-patches-explained` and :ref:`l-patch-coverage`.
+
+    :param patch_sympy: fix missing method ``name`` for IntegerConstant
+    :param patch_torch: patches :epkg:`torch` with supported implementation
+    :param patch_transformers: patches :epkg:`transformers` with supported implementation
+    :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
+    :param catch_constraints: catch constraints related to dynamic shapes,
+        as a result, some dynamic dimension may turn into static ones,
+        the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
+        can be put to stop at that stage.
+    :param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
+        to stop the export as soon as an issue is detected with dynamic shapes
+        and show a stack trace indicating the exact location of the issue,
+        ``if stop_if_static > 1``, more methods are replace to catch more
+        issues
+    :param patch: if False, disable all patches but keeps the registration of
+        serialization functions if other patch functions are enabled
+    :param custom_patches: to apply custom patches,
+        every patched class must define static attributes
+        ``_PATCHES_``, ``_PATCHED_CLASS_``
+    :param rewrite: list of methods to automatically rewrite
+        before exporting, methods with control flow need to be rewritten
+        before being exported if the execution path depends on the inputs,
+        this is done by function :func:`transform_method
+        <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
+        its documentation provides possible values
+    :param dump_rewriting: dumps rewriting information in file beginning with that prefix
+    :param patch_details: if specified, this class is used to stored every rewritten done.
+    :param verbose: to show which patches is applied
+
+    The list of available patches.
+
+    * ``torch.jit.isinstance``
+    * ``torch._dynamo.mark_static_address``
+    * ``torch._subclasses.fake_impls.infer_size``
+    * ``torch.vmap``
+    * fix missing method ``name`` for ``sympy.S.IntegerConstant``
+    * ``AttentionMaskConverter._make_causal_mask``
+    * Serialization of ``MambaCache`` (in :epkg:`transformers`)
+    * Serialization of ``DynamicCache`` (in :epkg:`transformers`)
+    * reduce errors due to shape inference
+    * fixes some transformers classes,
+      see :mod:`onnx_diagnostic.torch_export_patches.patches.patch_transformers`
+
+    Serialization issues happen when a module takes one input or output
+    has a type :func:`torch.export.export` cannot serialize.
+
+    Examples:
+
+    .. code-block:: python
+
+        with torch_export_patches(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            onx = to_onnx(..., inputs, ...)
+
+    .. code-block:: python
+
+        with torch_export_patches(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            onx = torch.onnx.export(..., inputs, ...)
+
+    It can be used as well to fix the torch export:
+
+    .. code-block:: python
+
+        with torch_export_patches(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            ep = torch.export.export(..., inputs, ...)
+
+    When running the model through the exported program, only the
+    serialization functions need to be restored:
+
+    .. code-block:: python
+
+        with register_additional_serialization_functions() as modificator:
+            inputs = modificator(inputs)
+            ep = torch.export.export(..., inputs, ...)
+
+    When exporting a model with a cache, the following error message
+    may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
+    It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
+    """
+    if verbose:
+        print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
+        print(f" . patch_torch={patch_torch!r}")
+        print(f" . patch_transformers={patch_transformers!r}")
+        print(f" . patch_diffusers={patch_diffusers!r}")
+        print(f" . catch_constraints={catch_constraints!r}")
+        print(f" . stop_if_static={stop_if_static!r}")
+        print(f" . patch={patch!r}")
+        print(f" . custom_patches={custom_patches!r}")
+        print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
+
+    if rewrite:
+        from .patch_module import torch_export_rewrite
+
+        with (
+            torch_export_rewrite(
+                rewrite=rewrite,
+                dump_rewriting=dump_rewriting,
+                verbose=verbose,
+                patch_details=patch_details,
+            ),
+            torch_export_patches(  # type: ignore[var-annotated]
+                patch_sympy=patch_sympy,
+                patch_torch=patch_torch,
+                patch_transformers=patch_transformers,
+                patch_diffusers=patch_diffusers,
+                catch_constraints=catch_constraints,
+                stop_if_static=stop_if_static,
+                verbose=verbose,
+                patch=patch,
+                custom_patches=custom_patches,
+                patch_details=patch_details,
+            ) as f,
+        ):
+            try:
+                yield f
+            finally:
+                pass
+    elif not patch:
+        fct_callable = lambda x: x  # noqa: E731
+        done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )
+        try:
+            yield fct_callable
+        finally:
+            unregister_cache_serialization(done, verbose=verbose)
+    else:
+        if verbose:
+            print(
+                "[torch_export_patches] replace torch.jit.isinstance, "
+                "torch._dynamo.mark_static_address"
+            )
+
+        # caches
+
+        cache_done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )
+
+        # patches
+
+        if patch_sympy:
+            (f_sympy_name,) = _patch_sympy(verbose, patch_details)
+
+        if patch_torch:
+            (
+                f___constrain_user_specified_dimhint_range,
+                f__broadcast_in_dim_meta,
+                f__broadcast_shapes,
+                f__check_input_constraints_for_graph,
+                f__maybe_broadcast,
+                f_broadcast_in_dim,
+                f_infer_size,
+                f_jit_isinstance,
+                f_mark_static_address,
+                f_produce_guards_and_solve_constraints,
+                f_shape_env__check_frozen,
+                f_shape_env__evaluate_expr,
+                f_shape_env__log_guard,
+                f_shape_env__set_replacement,
+                f_vmap,
+            ) = _patch_torch(
+                verbose, patch_details, patch_torch, catch_constraints, stop_if_static
+            )
+
+        if patch_transformers:
+            (
+                f_transformers__vmap_for_bhqkv,
+                f_transformers_eager_mask,
+                f_transformers_sdpa_attention_forward,
+                f_transformers_sdpa_mask,
+                f_transformers_sdpa_mask_recent_torch,
+                revert_patches_info,
+            ) = _patch_transformers(verbose, patch_details)
+
+        if custom_patches:
+            if verbose:
+                print("[torch_export_patches] applies custom patches")
+            revert_custom_patches_info = patch_module_or_classes(
+                custom_patches, verbose=verbose, patch_details=patch_details
+            )
+
+        # export
+
+        fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
+
+        if verbose:
+            print("[torch_export_patches] done patching")
+
+        try:
+            yield fct_callable
+        finally:
+
+            # unpatch
+
+            if verbose:
+                print("[torch_export_patches] remove patches")
+
+            if patch_sympy:
+                _unpatch_sympy(verbose, f_sympy_name)
+
+            if patch_torch:
+                _unpatch_torch(
+                    verbose,
+                    patch_details,
+                    patch_torch,
+                    catch_constraints,
+                    stop_if_static,
+                    f___constrain_user_specified_dimhint_range,
+                    f__broadcast_in_dim_meta,
+                    f__broadcast_shapes,
+                    f__check_input_constraints_for_graph,
+                    f__maybe_broadcast,
+                    f_broadcast_in_dim,
+                    f_infer_size,
+                    f_jit_isinstance,
+                    f_mark_static_address,
+                    f_produce_guards_and_solve_constraints,
+                    f_shape_env__check_frozen,
+                    f_shape_env__evaluate_expr,
+                    f_shape_env__log_guard,
+                    f_shape_env__set_replacement,
+                    f_vmap,
+                )
+
+            if patch_transformers:
+                _unpatch_transformers(
+                    verbose,
+                    patch_details,
+                    f_transformers__vmap_for_bhqkv,
+                    f_transformers_eager_mask,
+                    f_transformers_sdpa_attention_forward,
+                    f_transformers_sdpa_mask,
+                    f_transformers_sdpa_mask_recent_torch,
+                    revert_patches_info,
+                )
+
+            if custom_patches:
+                if verbose:
+                    print("[torch_export_patches] unpatches custom patches")
+                unpatch_module_or_classes(
+                    custom_patches, revert_custom_patches_info, verbose=verbose
+                )
+
+            ########
+            # caches
+            ########
+
+            unregister_cache_serialization(cache_done, verbose=verbose)
+
+
+def replacement_before_exporting(args: Any) -> Any:
+    """Does replacements on the given inputs if needed."""
+    if args is None:
+        return None
+    if isinstance(args, (int, float)):
+        return args
+    if type(args) not in {dict, tuple, list}:
+        # BaseModelOutput is a dict
+        return args
+    if isinstance(args, dict):
+        return {k: replacement_before_exporting(v) for k, v in args.items()}
+    if isinstance(args, tuple):
+        return tuple(replacement_before_exporting(v) for v in args)
+    if isinstance(args, list):
+        return [replacement_before_exporting(v) for v in args]
+
+    return args
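
The custom_patches argument above expects classes shaped like the ones patch_module_or_classes consumes: a class named with a patched_ prefix that exposes _PATCHED_CLASS_ (the class whose methods are replaced) and _PATCHES_ (the method names to copy over). Below is a hedged sketch; MyModule and the rewrite of its forward are invented for illustration, and the top-level import again assumes the subpackage re-exports torch_export_patches. Only the two class attributes and the context-manager call mirror the code in the diff:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed re-export


class MyModule(torch.nn.Module):
    def forward(self, x):
        # data-dependent branch: torch.export would specialize or reject it
        return x.abs() if x.sum() > 0 else -x


class patched_MyModule:
    _PATCHED_CLASS_ = MyModule   # class whose methods get replaced
    _PATCHES_ = ["forward"]      # method names copied from this class onto the target

    def forward(self, x):
        # branch-free rewrite of the original forward
        return torch.where(x.sum() > 0, x.abs(), -x)


# patch_module_or_classes swaps MyModule.forward for the patched version inside
# the context and restores the original method on exit.
with torch_export_patches(custom_patches=[patched_MyModule]) as f:
    ep = torch.export.export(MyModule(), (torch.randn(3),))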