onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +2 -1
- onnx_diagnostic/export/shape_helper.py +47 -70
- onnx_diagnostic/ext_test_case.py +11 -0
- onnx_diagnostic/helpers/cache_helper.py +38 -7
- onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
- onnx_diagnostic/helpers/helper.py +27 -33
- onnx_diagnostic/helpers/log_helper.py +109 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/model_builder_helper.py +132 -2
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +4 -0
- onnx_diagnostic/helpers/rt_helper.py +393 -43
- onnx_diagnostic/helpers/torch_helper.py +20 -1
- onnx_diagnostic/tasks/__init__.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
- onnx_diagnostic/tasks/feature_extraction.py +2 -8
- onnx_diagnostic/tasks/image_text_to_text.py +10 -8
- onnx_diagnostic/tasks/summarization.py +2 -8
- onnx_diagnostic/tasks/text2text_generation.py +3 -8
- onnx_diagnostic/tasks/text_generation.py +86 -65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
- onnx_diagnostic/torch_models/validate.py +626 -228
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
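
The largest change in this release is a rewrite of `onnx_diagnostic/torch_export_patches/onnx_export_errors.py`, reproduced in the diff below: the inline patching code is split into paired `_patch_sympy`/`_unpatch_sympy`, `_patch_torch`/`_unpatch_torch` and `_patch_transformers`/`_unpatch_transformers` helpers, and a new `patch_details` argument, backed by the new `patch_details.py` module, lets callers record every patch that gets applied. A minimal usage sketch, assuming `PatchDetails()` can be constructed without arguments and simply collects the entries registered through the `patch_details.append(...)` calls visible in the diff:

```python
# Hedged sketch based only on the signatures visible in this diff; the
# PatchDetails constructor and any reporting helpers on it are assumptions,
# not documented API.
import torch

from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails


class Tiny(torch.nn.Module):
    def forward(self, x):
        return x * 2


details = PatchDetails()  # assumed no-argument constructor
with torch_export_patches(patch_transformers=True, patch_details=details, verbose=1):
    ep = torch.export.export(Tiny(), (torch.randn(2, 3),))

# `details` now holds one entry per applied patch, appended by the
# patch_details.append(...) calls added throughout onnx_export_errors.py.
```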
@@ -1,5 +1,6 @@
 import functools
 import importlib
+import inspect
 import contextlib
 import re
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -8,6 +9,7 @@ from .onnx_export_serialization import (
     unregister_cache_serialization,
 )
 from .patches import patch_transformers as patch_transformers_list
+from .patch_details import PatchDetails


 def get_function(name: str) -> Tuple[type, Callable]:
@@ -51,7 +53,9 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
     return name, to_patch


-def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
+def patch_module_or_classes(
+    mod, verbose: int = 0, patch_details: Optional[PatchDetails] = None
+) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
     ``cls._PATCHED_CLASS_`` defines the class to patch,
@@ -61,13 +65,16 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call

     :param mod: module of list of clsses to patch
     :param verbose: verbosity
+    :param patch_details: used to store information about the applied patches
     :return: patch info
     """
     if isinstance(mod, list):
         to_patch = mod
         name = "list"
+        list_name = "auto/list"
     else:
         name, to_patch = get_patches(mod, verbose)
+        list_name = f"auto/{mod.__name__.split('.')[-1]}"

     res = {}
     for cls in to_patch:
@@ -76,9 +83,15 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
             keep = {}
             original = cls["module"]
             f = cls["function"]
+            assert not f.__name__.startswith("patched_"), (
+                f"The function {f} was already patched or the patch was not removed, "
+                f"original={original}"
+            )
             res[f] = f
             if verbose:
                 print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+            if patch_details:
+                patch_details.append(list_name, getattr(original, f.__name__), cls["patch"])
             setattr(original, f.__name__, cls["patch"])
             continue

@@ -89,6 +102,18 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call

         keep = {n: getattr(original, n, None) for n in methods}
         for n in methods:
+            if patch_details:
+                if hasattr(original, n):
+                    p = patch_details.append(list_name, getattr(original, n), getattr(cls, n))
+                else:
+                    p = patch_details.append(
+                        list_name, f"{original.__name__}{n}", getattr(cls, n)
+                    )
+                if "@patched_dynamic_rope_update" in inspect.getsource(getattr(cls, n)):
+                    # a tweak to include that patch.
+                    f = patch_details.find("patched_dynamic_rope_update")
+                    if f is not None:
+                        p.add_dependency(f)
             setattr(original, n, getattr(cls, n))
         res[cls] = keep

@@ -157,6 +182,628 @@ def register_additional_serialization_functions(
         unregister_cache_serialization(done, verbose=verbose)


+def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Optional[Callable], ...]:
+    import sympy
+
+    f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
+
+    if verbose:
+        print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
+        print("[torch_export_patches] patch sympy")
+
+    sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
+    if patch_details:
+        patch_details.append(
+            "sympy",
+            f_sympy_name or "sympy.core.numbers.IntegerConstant.name",
+            sympy.core.numbers.IntegerConstant.name,
+        )
+    return (f_sympy_name,)
+
+
+def _unpatch_sympy(verbose: int, f_sympy_name: Optional[Callable]):
+    # tracked by https://github.com/pytorch/pytorch/issues/143494
+    import sympy
+
+    if f_sympy_name:
+        sympy.core.numbers.IntegerConstant.name = f_sympy_name
+    else:
+        delattr(sympy.core.numbers.IntegerConstant, "name")
+
+    if verbose:
+        print("[torch_export_patches] restored sympy functions")
+
+
+def _patch_torch(
+    verbose: int,
+    patch_details: PatchDetails,
+    patch_torch: int,
+    catch_constraints: bool,
+    stop_if_static: int,
+) -> Tuple[Optional[Callable], ...]:
+    import torch
+    import torch.jit
+    import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+    from .patches.patch_torch import (
+        patched_infer_size,
+        patched_vmap,
+        patched__broadcast_shapes,
+        patched__constrain_user_specified_dimhint_range,
+        _catch_produce_guards_and_solve_constraints,
+        patch__check_input_constraints_for_graph,
+        patched__broadcast_in_dim_meta,
+        patched__broadcast_in_dim_meta_level_2,
+        patched__maybe_broadcast,
+        patched_ShapeEnv,
+    )
+
+    f___constrain_user_specified_dimhint_range = None
+    f__broadcast_in_dim_meta = None
+    f__broadcast_shapes = None
+    f__check_input_constraints_for_graph = None
+    f__maybe_broadcast = None
+    f_broadcast_in_dim = None
+    f_infer_size = None
+    f_jit_isinstance = None
+    f_mark_static_address = None
+    f_produce_guards_and_solve_constraints = None
+    f_shape_env__check_frozen = None
+    f_shape_env__evaluate_expr = None
+    f_shape_env__log_guard = None
+    f_shape_env__set_replacement = None
+    f_vmap = None
+
+    if verbose:
+        print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
+        print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
+        print("[torch_export_patches] patch pytorch")
+
+    # torch.vmap
+    f_vmap = torch.vmap
+    torch.vmap = patched_vmap
+
+    # torch.jit.isinstance
+    f_jit_isinstance = torch.jit.isinstance
+    torch.jit.isinstance = isinstance
+
+    # torch._dynamo.mark_static_address
+    f_mark_static_address = torch._dynamo.mark_static_address
+    torch._dynamo.mark_static_address = lambda *_, **y_: None
+
+    # torch._subclasses.fake_impls.infer_size
+    f_infer_size = torch._subclasses.fake_impls.infer_size
+    torch._subclasses.fake_impls.infer_size = patched_infer_size
+    if patch_details:
+        patch_details.append("torch", f_infer_size, patched_infer_size)
+
+    # torch._refs._broadcast_shapes
+    f__broadcast_shapes = torch._refs._broadcast_shapes
+    torch._refs._broadcast_shapes = patched__broadcast_shapes
+    torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
+    if patch_details:
+        patch_details.append("torch", f__broadcast_shapes, patched__broadcast_shapes)
+
+    # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+    f___constrain_user_specified_dimhint_range = (
+        torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+    )
+    torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+        patched__constrain_user_specified_dimhint_range
+    )
+    if patch_details:
+        patch_details.append(
+            "torch",
+            f___constrain_user_specified_dimhint_range,
+            patched__constrain_user_specified_dimhint_range,
+        )
+
+    # torch._prims._broadcast_in_dim_meta
+    f_broadcast_in_dim = torch._prims.broadcast_in_dim
+    f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+    _patched_dim_f = (
+        patched__broadcast_in_dim_meta_level_2
+        if patch_torch == 2
+        else patched__broadcast_in_dim_meta
+    )
+    torch._prims._broadcast_in_dim_meta = _patched_dim_f
+    torch._prims.broadcast_in_dim = _patched_dim_f
+    if patch_details:
+        patch_details.append("torch", f__broadcast_in_dim_meta, _patched_dim_f)
+
+    # torch._refs._maybe_broadcast
+    f__maybe_broadcast = torch._refs._maybe_broadcast
+    torch._refs._maybe_broadcast = patched__maybe_broadcast
+    if patch_details:
+        patch_details.append("torch", f__maybe_broadcast, patched__maybe_broadcast)
+
+    # ShapeEnv
+    f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+    ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+    if patch_details:
+        patch_details.append(
+            "torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr
+        )
+
+    # torch._export.non_strict_utils.produce_guards_and_solve_constraints
+    if catch_constraints:
+        if verbose:
+            print("[torch_export_patches] modifies shape constraints")
+        f_produce_guards_and_solve_constraints = (
+            torch._export.non_strict_utils.produce_guards_and_solve_constraints
+        )
+        f__check_input_constraints_for_graph = (
+            torch._export.utils._check_input_constraints_for_graph
+        )
+        torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+            lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints(
+                f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs
+            )
+        )
+        torch._export.utils._check_input_constraints_for_graph = (
+            lambda *args, **kwargs: patch__check_input_constraints_for_graph(
+                f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs
+            )
+        )
+
+    if patch_torch and stop_if_static:
+        ShapeEnv._log_guard_remember = ShapeEnv._log_guard
+
+        if verbose:
+            print("[torch_export_patches] assert when a dynamic dimension turns static")
+            print("[torch_export_patches] replaces ShapeEnv._set_replacement")
+
+        f_shape_env__set_replacement = ShapeEnv._set_replacement
+        ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
+        if patch_details:
+            patch_details.append(
+                "torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement
+            )
+
+        if verbose:
+            print("[torch_export_patches] replaces ShapeEnv._log_guard")
+        f_shape_env__log_guard = ShapeEnv._log_guard
+        ShapeEnv._log_guard = patched_ShapeEnv._log_guard
+        if patch_details:
+            patch_details.append("torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard)
+
+        if stop_if_static > 1:
+            if verbose:
+                print("[torch_export_patches] replaces ShapeEnv._check_frozen")
+            f_shape_env__check_frozen = ShapeEnv._check_frozen
+            ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
+            if patch_details:
+                patch_details.append(
+                    "torch", f_shape_env__check_frozen, ShapeEnv._check_frozen
+                )
+    return (
+        f___constrain_user_specified_dimhint_range,
+        f__broadcast_in_dim_meta,
+        f__broadcast_shapes,
+        f__check_input_constraints_for_graph,
+        f__maybe_broadcast,
+        f_broadcast_in_dim,
+        f_infer_size,
+        f_jit_isinstance,
+        f_mark_static_address,
+        f_produce_guards_and_solve_constraints,
+        f_shape_env__check_frozen,
+        f_shape_env__evaluate_expr,
+        f_shape_env__log_guard,
+        f_shape_env__set_replacement,
+        f_vmap,
+    )
+
+
+def _unpatch_torch(
+    verbose: int,
+    _patch_details: PatchDetails,
+    patch_torch: int,
+    catch_constraints: bool,
+    stop_if_static: int,
+    f___constrain_user_specified_dimhint_range: Optional[Callable],
+    f__broadcast_in_dim_meta: Optional[Callable],
+    f__broadcast_shapes: Optional[Callable],
+    f__check_input_constraints_for_graph: Optional[Callable],
+    f__maybe_broadcast: Optional[Callable],
+    f_broadcast_in_dim: Optional[Callable],
+    f_infer_size: Optional[Callable],
+    f_jit_isinstance: Optional[Callable],
+    f_mark_static_address: Optional[Callable],
+    f_produce_guards_and_solve_constraints: Optional[Callable],
+    f_shape_env__check_frozen: Optional[Callable],
+    f_shape_env__evaluate_expr: Optional[Callable],
+    f_shape_env__log_guard: Optional[Callable],
+    f_shape_env__set_replacement: Optional[Callable],
+    f_vmap: Optional[Callable],
+):
+    import torch
+    import torch.jit
+    import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    # this should disappear when torch.jit is removed
+    torch.vmap = f_vmap
+    torch.jit.isinstance = f_jit_isinstance
+    torch._dynamo.mark_static_address = f_mark_static_address
+    # tracked by https://github.com/pytorch/pytorch/issues/143495
+    torch._subclasses.fake_impls.infer_size = f_infer_size
+    torch._refs._broadcast_shapes = f__broadcast_shapes
+    torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
+    torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+        f___constrain_user_specified_dimhint_range
+    )
+    torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+    torch._prims.broadcast_in_dim = f_broadcast_in_dim
+    torch._refs._maybe_broadcast = f__maybe_broadcast
+    ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
+
+    if verbose:
+        print("[torch_export_patches] restored pytorch functions")
+
+    if patch_torch and stop_if_static:
+        if verbose:
+            print("[torch_export_patches] restored ShapeEnv._set_replacement")
+
+        ShapeEnv._set_replacement = f_shape_env__set_replacement
+
+        if verbose:
+            print("[torch_export_patches] restored ShapeEnv._log_guard")
+
+        ShapeEnv._log_guard = f_shape_env__log_guard
+
+        if stop_if_static > 1:
+            if verbose:
+                print("[torch_export_patches] restored ShapeEnv._check_frozen")
+            ShapeEnv._check_frozen = f_shape_env__check_frozen
+
+    if patch_torch and catch_constraints:
+        # to catch or skip dynamic_shapes issues
+        torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+            f_produce_guards_and_solve_constraints
+        )
+        torch._export.utils._check_input_constraints_for_graph = (
+            f__check_input_constraints_for_graph
+        )
+        if verbose:
+            print("[torch_export_patches] restored shape constraints")
+
+
+def _patch_transformers(
+    verbose: int, patch_details: PatchDetails
+) -> Tuple[Optional[Callable], ...]:
+    import transformers
+
+    try:
+        import transformers.masking_utils as masking_utils
+    except ImportError:
+        masking_utils = None
+
+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
+    try:
+        import transformers.modeling_rope_utils as modeling_rope_utils
+    except ImportError:
+        modeling_rope_utils = None
+
+    if (
+        patch_details
+        and modeling_rope_utils
+        and hasattr(modeling_rope_utils, "dynamic_rope_update")
+    ):
+        patch_details.append(
+            "patch_transformers",
+            modeling_rope_utils.dynamic_rope_update,
+            patch_transformers_list.patched_dynamic_rope_update,
+        )
+
+    if verbose:
+        print(f"[torch_export_patches] transformers.__version__={transformers.__version__!r}")
+    assert not sdpa_attention.sdpa_attention_forward.__name__.startswith("patched_"), (
+        f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+        f"sdpa_attention.sdpa_attention_forward={sdpa_attention.sdpa_attention_forward}"
+    )
+
+    f_transformers__vmap_for_bhqkv = None
+    f_transformers_eager_mask = None
+    f_transformers_sdpa_attention_forward = None
+    f_transformers_sdpa_mask = None
+    f_transformers_sdpa_mask_recent_torch = None
+
+    if (  # vmap
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "_vmap_for_bhqkv")
+    ):
+        if verbose:
+            print("[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv")
+        f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
+        masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers__vmap_for_bhqkv,
+                patch_transformers_list.patched__vmap_for_bhqkv,
+            )
+
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+        f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+        masking_utils.sdpa_mask_recent_torch = (
+            patch_transformers_list.patched_sdpa_mask_recent_torch
+        )
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_sdpa_mask_recent_torch,
+                patch_transformers_list.patched_sdpa_mask_recent_torch,
+            )
+        if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+            if verbose:
+                print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask")
+            f_transformers_sdpa_mask = masking_utils.sdpa_mask
+            masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch
+            if patch_details:
+                patch_details.append(
+                    "transformers",
+                    f_transformers_sdpa_mask,
+                    patch_transformers_list.patched_sdpa_mask_recent_torch,
+                )
+        else:
+            f_transformers_sdpa_mask = None
+
+    if (  # eager_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "eager_mask")
+    ):
+        if verbose:
+            print("[torch_export_patches] patches transformers.masking_utils.eager_mask")
+        f_transformers_eager_mask = masking_utils.eager_mask
+        masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_eager_mask,
+                patch_transformers_list.patched_eager_mask,
+            )
+        if (
+            "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+            == f_transformers_eager_mask
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                patch_transformers_list.patched_eager_mask
+            )
+
+    if (  # sdpa_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+        and f_transformers_sdpa_mask is not None
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.masking_utils.sdpa_mask "
+                "in ALL_MASK_ATTENTION_FUNCTIONS"
+            )
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] == f_transformers_sdpa_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                patch_transformers_list.patched_sdpa_mask_recent_torch
+            )
+
+    if (  # sdpa_attention_forward
+        sdpa_attention is not None
+        and modeling_utils is not None
+        and hasattr(sdpa_attention, "sdpa_attention_forward")
+        and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+        and hasattr(modeling_utils, "AttentionInterface")
+    ):
+        if verbose:
+            print(
+                "[torch_export_patches] patches "
+                "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+            )
+        f_transformers_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+        assert not f_transformers_sdpa_attention_forward.__name__.startswith("patched_"), (
+            f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+            f"sdpa_attention.sdpa_attention_forward={f_transformers_sdpa_attention_forward}"
+        )
+        sdpa_attention.sdpa_attention_forward = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        modeling_utils.sdpa_attention_forward = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            patch_transformers_list.patched_sdpa_attention_forward
+        )
+        if patch_details:
+            patch_details.append(
+                "transformers",
+                f_transformers_sdpa_attention_forward,
+                patch_transformers_list.patched_sdpa_attention_forward,
+            )
+
+    revert_patches_info = patch_module_or_classes(
+        patch_transformers_list, verbose=verbose, patch_details=patch_details
+    )
+
+    return (
+        f_transformers__vmap_for_bhqkv,
+        f_transformers_eager_mask,
+        f_transformers_sdpa_attention_forward,
+        f_transformers_sdpa_mask,
+        f_transformers_sdpa_mask_recent_torch,
+        revert_patches_info,
+    )
+
+
+def _unpatch_transformers(
+    verbose: int,
+    _patch_details: PatchDetails,
+    f_transformers__vmap_for_bhqkv: Optional[Callable],
+    f_transformers_eager_mask: Optional[Callable],
+    f_transformers_sdpa_attention_forward: Optional[Callable],
+    f_transformers_sdpa_mask: Optional[Callable],
+    f_transformers_sdpa_mask_recent_torch: Optional[Callable],
+    revert_patches_info: Optional[Callable],
+):
+
+    try:
+        import transformers.masking_utils as masking_utils
+    except ImportError:
+        masking_utils = None
+
+    try:
+        import transformers.integrations.sdpa_attention as sdpa_attention
+    except ImportError:
+        sdpa_attention = None
+
+    try:
+        import transformers.modeling_utils as modeling_utils
+    except ImportError:
+        modeling_utils = None
+
+    try:
+        import transformers.masking_utils as masking_utils
+    except ImportError:
+        masking_utils = None
+    if verbose:
+        print("[torch_export_patches] unpatches transformers")
+
+    if (  # vmap
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "_vmap_for_bhqkv")
+    ):
+        assert f_transformers__vmap_for_bhqkv.__name__ == "_vmap_for_bhqkv", (
+            f"corrupted function '_vmap_for_bhqkv', its name is "
+            f"{f_transformers__vmap_for_bhqkv.__name__!r}"
+        )
+        masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
+        if verbose:
+            print("[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv")
+
+        assert f_transformers_sdpa_mask_recent_torch.__name__ == "sdpa_mask_recent_torch", (
+            f"corrupted function 'sdpa_mask_recent_torch', its name is "
+            f"{f_transformers_sdpa_mask_recent_torch.__name__!r}"
+        )
+        masking_utils.sdpa_mask_recent_torch = f_transformers_sdpa_mask_recent_torch
+
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.masking_utils.sdpa_mask_recent_torch"
+            )
+
+        if f_transformers_sdpa_mask is not None:
+            assert f_transformers_sdpa_mask.__name__ in (
+                "sdpa_mask",
+                "sdpa_mask_recent_torch",
+            ), (
+                f"corrupted function 'sdpa_mask', its name is "
+                f"{f_transformers_sdpa_mask.__name__!r}"
+            )
+            masking_utils.sdpa_mask = f_transformers_sdpa_mask
+            if verbose:
+                print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")
+
+    if (  # eager_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "eager_mask")
+    ):
+        assert f_transformers_eager_mask.__name__ == "eager_mask", (
+            f"corrupted function 'eager_mask', its name is "
+            f"{f_transformers_eager_mask.__name__!r}"
+        )
+        masking_utils.eager_mask = f_transformers_eager_mask
+        if verbose:
+            print("[torch_export_patches] restored transformers.masking_utils.eager_mask")
+        assert masking_utils.eager_mask.__name__ == "eager_mask", (
+            f"corrupted function 'eager_mask', its name is "
+            f"{masking_utils.eager_mask.__name__!r}"
+        )
+        if (
+            "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+            == patch_transformers_list.patched_eager_mask
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = f_transformers_eager_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.eager_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+            assert masking_utils.eager_mask.__name__ == "eager_mask", (
+                f"corrupted function 'eager_mask', its name is "
+                f"{masking_utils.eager_mask.__name__!r}"
+            )
+
+    if (  # sdpa_mask
+        masking_utils
+        and patch_transformers_list.patch_masking_utils
+        and hasattr(masking_utils, "sdpa_mask")
+    ):
+        if (
+            "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+            and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+            == patch_transformers_list.patched_sdpa_mask_recent_torch
+        ):
+            masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = f_transformers_sdpa_mask
+            if verbose:
+                print(
+                    "[torch_export_patches] restored "
+                    "transformers.masking_utils.sdpa_mask "
+                    "in ALL_MASK_ATTENTION_FUNCTIONS"
+                )
+
+    if (  # sdpa_attention_forward
+        sdpa_attention is not None
+        and modeling_utils is not None
+        and hasattr(sdpa_attention, "sdpa_attention_forward")
+        and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+        and hasattr(modeling_utils, "AttentionInterface")
+    ):
+        sdpa_attention.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+        modeling_utils.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+        modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+            f_transformers_sdpa_attention_forward
+        )
+        if verbose:
+            print(
+                "[torch_export_patches] restored "
+                "transformers.integrations.sdpa_attention."
+                "sdpa_attention_forward"
+            )
+
+    unpatch_module_or_classes(patch_transformers_list, revert_patches_info, verbose=verbose)
+
+
 @contextlib.contextmanager
 def torch_export_patches(
     patch_sympy: bool = True,
@@ -170,6 +817,7 @@ def torch_export_patches(
     custom_patches: Optional[List[type["torch.nn.Module"]]] = None,  # noqa: F821
     rewrite: Optional[List[Callable]] = None,
     dump_rewriting: Optional[str] = None,
+    patch_details: Optional[PatchDetails] = None,
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -200,6 +848,7 @@ def torch_export_patches(
         <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
         its documentation provides possible values
     :param dump_rewriting: dumps rewriting information in file beginning with that prefix
+    :param patch_details: if specified, this class is used to stored every rewritten done.
     :param verbose: to show which patches is applied

     The list of available patches.
@@ -270,7 +919,10 @@ def torch_export_patches(

         with (
             torch_export_rewrite(
-                rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
+                rewrite=rewrite,
+                dump_rewriting=dump_rewriting,
+                verbose=verbose,
+                patch_details=patch_details,
             ),
             torch_export_patches(  # type: ignore[var-annotated]
                 patch_sympy=patch_sympy,
@@ -282,6 +934,7 @@ def torch_export_patches(
                 verbose=verbose,
                 patch=patch,
                 custom_patches=custom_patches,
+                patch_details=patch_details,
             ) as f,
         ):
             try:
@@ -300,19 +953,13 @@ def torch_export_patches(
             finally:
                 unregister_cache_serialization(done, verbose=verbose)
     else:
-        import torch
-        import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
-        import torch.jit
-
         if verbose:
             print(
                 "[torch_export_patches] replace torch.jit.isinstance, "
                 "torch._dynamo.mark_static_address"
             )

-        ########
         # caches
-        ########

         cache_done = register_cache_serialization(
             patch_transformers=patch_transformers,
@@ -320,282 +967,50 @@ def torch_export_patches(
             verbose=verbose,
         )

-
-        # patch sympy
-        #############
+        # patches

         if patch_sympy:
-
-
-            f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
-
-            if verbose:
-                print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
-                print("[torch_export_patches] patch sympy")
-
-            sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
-
-            ###############
-            # patch pytorch
-            ###############
+            (f_sympy_name,) = _patch_sympy(verbose, patch_details)

         if patch_torch:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # torch.vmap
-            f_vmap = torch.vmap
-            torch.vmap = patched_vmap
-
-            # torch.jit.isinstance
-            f_jit_isinstance = torch.jit.isinstance
-            torch.jit.isinstance = isinstance
-
-            # torch._dynamo.mark_static_address
-            f_mark_static_address = torch._dynamo.mark_static_address
-            torch._dynamo.mark_static_address = lambda *_, **y_: None
-
-            # torch._subclasses.fake_impls.infer_size
-            f_infer_size = torch._subclasses.fake_impls.infer_size
-            torch._subclasses.fake_impls.infer_size = patched_infer_size
-
-            # torch._refs._broadcast_shapes
-            f__broadcast_shapes = torch._refs._broadcast_shapes
-            torch._refs._broadcast_shapes = patched__broadcast_shapes
-            torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
-
-            # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
-            f___constrain_user_specified_dimhint_range = (
-                torch._export.non_strict_utils._constrain_user_specified_dimhint_range
-            )
-            torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
-                patched__constrain_user_specified_dimhint_range
-            )
-
-            # torch._prims._broadcast_in_dim_meta
-            f_broadcast_in_dim = torch._prims.broadcast_in_dim
-            f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-            _patched_dim_f = (
-                patched__broadcast_in_dim_meta_level_2
-                if patch_torch == 2
-                else patched__broadcast_in_dim_meta
+            (
+                f___constrain_user_specified_dimhint_range,
+                f__broadcast_in_dim_meta,
+                f__broadcast_shapes,
+                f__check_input_constraints_for_graph,
+                f__maybe_broadcast,
+                f_broadcast_in_dim,
+                f_infer_size,
+                f_jit_isinstance,
+                f_mark_static_address,
+                f_produce_guards_and_solve_constraints,
+                f_shape_env__check_frozen,
+                f_shape_env__evaluate_expr,
+                f_shape_env__log_guard,
+                f_shape_env__set_replacement,
+                f_vmap,
+            ) = _patch_torch(
+                verbose, patch_details, patch_torch, catch_constraints, stop_if_static
             )
-            torch._prims._broadcast_in_dim_meta = _patched_dim_f
-            torch._prims.broadcast_in_dim = _patched_dim_f
-
-            # torch._refs._maybe_broadcast
-            f__maybe_broadcast = torch._refs._maybe_broadcast
-            torch._refs._maybe_broadcast = patched__maybe_broadcast
-
-            # ShapeEnv
-            f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
-            ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
-
-            # torch._export.non_strict_utils.produce_guards_and_solve_constraints
-            if patch_torch and catch_constraints:
-                if verbose:
-                    print("[torch_export_patches] modifies shape constraints")
-                f_produce_guards_and_solve_constraints = (
-                    torch._export.non_strict_utils.produce_guards_and_solve_constraints
-                )
-                f__check_input_constraints_for_graph = (
-                    torch._export.utils._check_input_constraints_for_graph
-                )
-                torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
-                    lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints(
-                        f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs
-                    )
-                )
-                torch._export.utils._check_input_constraints_for_graph = (
-                    lambda *args, **kwargs: patch__check_input_constraints_for_graph(
-                        f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs
-                    )
-                )
-
-            if patch_torch and stop_if_static:
-                ShapeEnv._log_guard_remember = ShapeEnv._log_guard
-
-                if verbose:
-                    print("[torch_export_patches] assert when a dynamic dimension turns static")
-                    print("[torch_export_patches] replaces ShapeEnv._set_replacement")
-
-                f_shape_env__set_replacement = ShapeEnv._set_replacement
-                ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
-
-                if verbose:
-                    print("[torch_export_patches] replaces ShapeEnv._log_guard")
-                f_shape_env__log_guard = ShapeEnv._log_guard
-                ShapeEnv._log_guard = patched_ShapeEnv._log_guard
-
-                if stop_if_static > 1:
-                    if verbose:
-                        print("[torch_export_patches] replaces ShapeEnv._check_frozen")
-                    f_shape_env__check_frozen = ShapeEnv._check_frozen
-                    ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
-
-            ####################
-            # patch transformers
-            ####################

         if patch_transformers:
-
-
-
-
-
-
-
-
-            sdpa_attention = None
-
-            try:
-                import transformers.modeling_utils as modeling_utils
-            except ImportError:
-                modeling_utils = None
-
-            if verbose:
-                import transformers
-
-                print(
-                    f"[torch_export_patches] transformers.__version__="
-                    f"{transformers.__version__!r}"
-                )
-            revert_patches_info = patch_module_or_classes(
-                patch_transformers_list, verbose=verbose
-            )
-
-            if (  # vmap
-                masking_utils
-                and patch_transformers_list.patch_masking_utils
-                and hasattr(masking_utils, "_vmap_for_bhqkv")
-            ):
-                if verbose:
-                    print(
-                        "[torch_export_patches] patches "
-                        "transformers.masking_utils._vmap_for_bhqkv"
-                    )
-                f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
-                masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
-
-                if verbose:
-                    print(
-                        "[torch_export_patches] patches "
-                        "transformers.masking_utils.sdpa_mask_recent_torch"
-                    )
-                f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
-                masking_utils.sdpa_mask_recent_torch = (
-                    patch_transformers_list.patched_sdpa_mask_recent_torch
-                )
-                if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
-                    if verbose:
-                        print(
-                            "[torch_export_patches] patches "
-                            "transformers.masking_utils.sdpa_mask"
-                        )
-                    f_transformers_sdpa_mask = masking_utils.sdpa_mask
-                    masking_utils.sdpa_mask = (
-                        patch_transformers_list.patched_sdpa_mask_recent_torch
-                    )
-                else:
-                    f_transformers_sdpa_mask = None
-
-            if (  # eager_mask
-                masking_utils
-                and patch_transformers_list.patch_masking_utils
-                and hasattr(masking_utils, "eager_mask")
-            ):
-                if verbose:
-                    print(
-                        "[torch_export_patches] patches "
-                        "transformers.masking_utils.eager_mask"
-                    )
-                f_transformers_eager_mask = masking_utils.eager_mask
-                masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
-                if (
-                    "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-                    and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
-                    == f_transformers_eager_mask
-                ):
-                    if verbose:
-                        print(
-                            "[torch_export_patches] patches "
-                            "transformers.masking_utils.eager_mask "
-                            "in ALL_MASK_ATTENTION_FUNCTIONS"
-                        )
-                    masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
-                        patch_transformers_list.patched_eager_mask
-                    )
-
-            if (  # sdpa_mask
-                masking_utils
-                and patch_transformers_list.patch_masking_utils
-                and hasattr(masking_utils, "sdpa_mask")
-                and f_transformers_sdpa_mask is not None
-            ):
-                if verbose:
-                    print(
-                        "[torch_export_patches] patches "
-                        "transformers.masking_utils.sdpa_mask "
-                        "in ALL_MASK_ATTENTION_FUNCTIONS"
-                    )
-                if (
-                    "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-                    and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
-                    == f_transformers_sdpa_mask
-                ):
-                    masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
-                        patch_transformers_list.patched_sdpa_mask_recent_torch
-                    )
-
-            if (  # sdpa_attention_forward
-                sdpa_attention is not None
-                and modeling_utils is not None
-                and hasattr(sdpa_attention, "sdpa_attention_forward")
-                and hasattr(sdpa_attention, "use_gqa_in_sdpa")
-                and hasattr(modeling_utils, "AttentionInterface")
-            ):
-                if verbose:
-                    print(
-                        "[torch_export_patches] patches "
-                        "transformers.integrations.sdpa_attention.sdpa_attention_forward"
-                    )
-                f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
-                sdpa_attention.sdpa_attention_forward = (
-                    patch_transformers_list.patched_sdpa_attention_forward
-                )
-                modeling_utils.sdpa_attention_forward = (
-                    patch_transformers_list.patched_sdpa_attention_forward
-                )
-                modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
-                    patch_transformers_list.patched_sdpa_attention_forward
-                )
+            (
+                f_transformers__vmap_for_bhqkv,
+                f_transformers_eager_mask,
+                f_transformers_sdpa_attention_forward,
+                f_transformers_sdpa_mask,
+                f_transformers_sdpa_mask_recent_torch,
+                revert_patches_info,
+            ) = _patch_transformers(verbose, patch_details)

         if custom_patches:
             if verbose:
                 print("[torch_export_patches] applies custom patches")
             revert_custom_patches_info = patch_module_or_classes(
-                custom_patches, verbose=verbose
+                custom_patches, verbose=verbose, patch_details=patch_details
             )

-        ########
         # export
-        ########

         fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)

@@ -605,73 +1020,50 @@ def torch_export_patches(
         try:
             yield fct_callable
         finally:
-
-            #
-            #######
+
+            # unpatch

             if verbose:
                 print("[torch_export_patches] remove patches")

             if patch_sympy:
-
-                if f_sympy_name:
-                    sympy.core.numbers.IntegerConstant.name = f_sympy_name
-                else:
-                    delattr(sympy.core.numbers.IntegerConstant, "name")
-
-                if verbose:
-                    print("[torch_export_patches] restored sympy functions")
-
-                #######
-                # torch
-                #######
+                _unpatch_sympy(verbose, f_sympy_name)

             if patch_torch:
-
-
-
-
-
-
-
-
-
-
+                _unpatch_torch(
+                    verbose,
+                    patch_details,
+                    patch_torch,
+                    catch_constraints,
+                    stop_if_static,
+                    f___constrain_user_specified_dimhint_range,
+                    f__broadcast_in_dim_meta,
+                    f__broadcast_shapes,
+                    f__check_input_constraints_for_graph,
+                    f__maybe_broadcast,
+                    f_broadcast_in_dim,
+                    f_infer_size,
+                    f_jit_isinstance,
+                    f_mark_static_address,
+                    f_produce_guards_and_solve_constraints,
+                    f_shape_env__check_frozen,
+                    f_shape_env__evaluate_expr,
+                    f_shape_env__log_guard,
+                    f_shape_env__set_replacement,
+                    f_vmap,
                 )
-                torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
-                torch._prims.broadcast_in_dim = f_broadcast_in_dim
-                torch._refs._maybe_broadcast = f__maybe_broadcast
-                ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
-
-                if verbose:
-                    print("[torch_export_patches] restored pytorch functions")
-
-                if patch_torch and stop_if_static:
-                    if verbose:
-                        print("[torch_export_patches] restored ShapeEnv._set_replacement")
-
-                    ShapeEnv._set_replacement = f_shape_env__set_replacement
-
-                    if verbose:
-                        print("[torch_export_patches] restored ShapeEnv._log_guard")

-
-
-
-
-
-
-
-
-
-
-                    f_produce_guards_and_solve_constraints
-                )
-                torch._export.utils._check_input_constraints_for_graph = (
-                    f__check_input_constraints_for_graph
+            if patch_transformers:
+                _unpatch_transformers(
+                    verbose,
+                    patch_details,
+                    f_transformers__vmap_for_bhqkv,
+                    f_transformers_eager_mask,
+                    f_transformers_sdpa_attention_forward,
+                    f_transformers_sdpa_mask,
+                    f_transformers_sdpa_mask_recent_torch,
+                    revert_patches_info,
                 )
-                if verbose:
-                    print("[torch_export_patches] restored shape constraints")

             if custom_patches:
                 if verbose:
@@ -680,118 +1072,6 @@ def torch_export_patches(
                     custom_patches, revert_custom_patches_info, verbose=verbose
                 )

-            ##############
-            # transformers
-            ##############
-
-            if patch_transformers:
-                try:
-                    import transformers.masking_utils as masking_utils
-                except ImportError:
-                    masking_utils = None
-                if verbose:
-                    print("[torch_export_patches] unpatches transformers")
-                unpatch_module_or_classes(
-                    patch_transformers_list, revert_patches_info, verbose=verbose
-                )
-
-                if (  # vmap
-                    masking_utils
-                    and patch_transformers_list.patch_masking_utils
-                    and hasattr(masking_utils, "_vmap_for_bhqkv")
-                ):
-                    masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
-
-                    if verbose:
-                        print(
-                            "[torch_export_patches] restored "
-                            "transformers.masking_utils._vmap_for_bhqkv"
-                        )
-
-                    masking_utils.sdpa_mask_recent_torch = (
-                        f_transformers_sdpa_mask_recent_torch
-                    )
-
-                    if verbose:
-                        print(
-                            "[torch_export_patches] restored "
-                            "transformers.masking_utils.sdpa_mask_recent_torch"
-                        )
-
-                    if f_transformers_sdpa_mask is not None:
-                        masking_utils.sdpa_mask = f_transformers_sdpa_mask
-                        if verbose:
-                            print(
-                                "[torch_export_patches] restored "
-                                "transformers.masking_utils.sdpa_mask"
-                            )
-
-                if (  # eager_mask
-                    masking_utils
-                    and patch_transformers_list.patch_masking_utils
-                    and hasattr(masking_utils, "eager_mask")
-                ):
-                    f_transformers_eager_mask = masking_utils.eager_mask
-                    masking_utils.eager_mask = f_transformers_eager_mask
-                    if verbose:
-                        print(
-                            "[torch_export_patches] restored "
-                            "transformers.masking_utils.eager_mask"
-                        )
-                    if (
-                        "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-                        and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
-                        == patch_transformers_list.patched_eager_mask
-                    ):
-                        masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
-                            f_transformers_eager_mask
-                        )
-                        if verbose:
-                            print(
-                                "[torch_export_patches] restored "
-                                "transformers.masking_utils.eager_mask "
-                                "in ALL_MASK_ATTENTION_FUNCTIONS"
-                            )
-
-                if (  # sdpa_mask
-                    masking_utils
-                    and patch_transformers_list.patch_masking_utils
-                    and hasattr(masking_utils, "sdpa_mask")
-                ):
-                    if (
-                        "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
-                        and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
-                        == patch_transformers_list.patched_sdpa_mask_recent_torch
-                    ):
-                        masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
-                            f_transformers_sdpa_mask
-                        )
-                        if verbose:
-                            print(
-                                "[torch_export_patches] restored "
-                                "transformers.masking_utils.sdpa_mask "
-                                "in ALL_MASK_ATTENTION_FUNCTIONS"
-                            )
-
-                if (  # sdpa_attention_forward
-                    sdpa_attention is not None
-                    and modeling_utils is not None
-                    and hasattr(sdpa_attention, "sdpa_attention_forward")
-                    and hasattr(sdpa_attention, "use_gqa_in_sdpa")
-                    and hasattr(modeling_utils, "AttentionInterface")
-                ):
-                    sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
-                    modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
-                    modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
-                        f_sdpa_attention_forward
-                    )
-                    if verbose:
-                        print(
-                            "[torch_export_patches] restored "
-                            "transformers.integrations.sdpa_attention."
-                            "sdpa_attention_forward"
-                        )
-
             ########
             # caches
             ########