onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
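The headline change in this release is the new optional patch_details argument threaded through torch_export_patches and patch_module_or_classes in onnx_export_errors.py, backed by the new patch_details.py module. Below is a minimal usage sketch based only on what the diff shows: PatchDetails lives in onnx_diagnostic.torch_export_patches.patch_details and records one entry per applied patch. The toy model, the dynamic shapes, and the way the collected entries are reported afterwards are illustrative assumptions, not part of this diff.

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches
    from onnx_diagnostic.torch_export_patches.patch_details import PatchDetails

    class Tiny(torch.nn.Module):  # toy model, only for illustration
        def forward(self, x):
            return x * 2

    details = PatchDetails()  # filled with one entry per patch applied inside the context
    with torch_export_patches(patch_transformers=True, patch_details=details):
        # the context manager applies the sympy/torch/transformers patches,
        # then restores every patched function in its finally block
        ep = torch.export.export(
            Tiny(),
            (torch.randn(2, 4),),
            dynamic_shapes=({0: torch.export.Dim("batch")},),
        )
    # `details` now describes which functions were replaced and by which patched
    # version; the reporting API is defined in the new patch_details.py, not shown here.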
@@ -1,5 +1,6 @@
  import functools
  import importlib
+ import inspect
  import contextlib
  import re
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -8,6 +9,7 @@ from .onnx_export_serialization import (
  unregister_cache_serialization,
  )
  from .patches import patch_transformers as patch_transformers_list
+ from .patch_details import PatchDetails
 
 
  def get_function(name: str) -> Tuple[type, Callable]:
@@ -51,7 +53,9 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
  return name, to_patch
 
 
- def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
+ def patch_module_or_classes(
+ mod, verbose: int = 0, patch_details: Optional[PatchDetails] = None
+ ) -> Dict[type, Dict[type, Callable]]:
  """
  Applies all patches defined in classes prefixed by ``patched_``
  ``cls._PATCHED_CLASS_`` defines the class to patch,
@@ -61,13 +65,16 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
 
  :param mod: module of list of clsses to patch
  :param verbose: verbosity
+ :param patch_details: used to store information about the applied patches
  :return: patch info
  """
  if isinstance(mod, list):
  to_patch = mod
  name = "list"
+ list_name = "auto/list"
  else:
  name, to_patch = get_patches(mod, verbose)
+ list_name = f"auto/{mod.__name__.split('.')[-1]}"
 
  res = {}
  for cls in to_patch:
@@ -76,9 +83,15 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
  keep = {}
  original = cls["module"]
  f = cls["function"]
+ assert not f.__name__.startswith("patched_"), (
+ f"The function {f} was already patched or the patch was not removed, "
+ f"original={original}"
+ )
  res[f] = f
  if verbose:
  print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+ if patch_details:
+ patch_details.append(list_name, getattr(original, f.__name__), cls["patch"])
  setattr(original, f.__name__, cls["patch"])
  continue
 
@@ -89,6 +102,18 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call
 
  keep = {n: getattr(original, n, None) for n in methods}
  for n in methods:
+ if patch_details:
+ if hasattr(original, n):
+ p = patch_details.append(list_name, getattr(original, n), getattr(cls, n))
+ else:
+ p = patch_details.append(
+ list_name, f"{original.__name__}{n}", getattr(cls, n)
+ )
+ if "@patched_dynamic_rope_update" in inspect.getsource(getattr(cls, n)):
+ # a tweak to include that patch.
+ f = patch_details.find("patched_dynamic_rope_update")
+ if f is not None:
+ p.add_dependency(f)
  setattr(original, n, getattr(cls, n))
  res[cls] = keep
 
@@ -157,6 +182,628 @@ def register_additional_serialization_functions(
  unregister_cache_serialization(done, verbose=verbose)
 
 
+ def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Optional[Callable], ...]:
+ import sympy
+
+ f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
+
+ if verbose:
+ print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
+ print("[torch_export_patches] patch sympy")
+
+ sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
+ if patch_details:
+ patch_details.append(
+ "sympy",
+ f_sympy_name or "sympy.core.numbers.IntegerConstant.name",
+ sympy.core.numbers.IntegerConstant.name,
+ )
+ return (f_sympy_name,)
+
+
+ def _unpatch_sympy(verbose: int, f_sympy_name: Optional[Callable]):
+ # tracked by https://github.com/pytorch/pytorch/issues/143494
+ import sympy
+
+ if f_sympy_name:
+ sympy.core.numbers.IntegerConstant.name = f_sympy_name
+ else:
+ delattr(sympy.core.numbers.IntegerConstant, "name")
+
+ if verbose:
+ print("[torch_export_patches] restored sympy functions")
+
+
+ def _patch_torch(
+ verbose: int,
+ patch_details: PatchDetails,
+ patch_torch: int,
+ catch_constraints: bool,
+ stop_if_static: int,
+ ) -> Tuple[Optional[Callable], ...]:
+ import torch
+ import torch.jit
+ import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
+ from .patches.patch_torch import (
+ patched_infer_size,
+ patched_vmap,
+ patched__broadcast_shapes,
+ patched__constrain_user_specified_dimhint_range,
+ _catch_produce_guards_and_solve_constraints,
+ patch__check_input_constraints_for_graph,
+ patched__broadcast_in_dim_meta,
+ patched__broadcast_in_dim_meta_level_2,
+ patched__maybe_broadcast,
+ patched_ShapeEnv,
+ )
+
+ f___constrain_user_specified_dimhint_range = None
+ f__broadcast_in_dim_meta = None
+ f__broadcast_shapes = None
+ f__check_input_constraints_for_graph = None
+ f__maybe_broadcast = None
+ f_broadcast_in_dim = None
+ f_infer_size = None
+ f_jit_isinstance = None
+ f_mark_static_address = None
+ f_produce_guards_and_solve_constraints = None
+ f_shape_env__check_frozen = None
+ f_shape_env__evaluate_expr = None
+ f_shape_env__log_guard = None
+ f_shape_env__set_replacement = None
+ f_vmap = None
+
+ if verbose:
+ print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
+ print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
+ print("[torch_export_patches] patch pytorch")
+
+ # torch.vmap
+ f_vmap = torch.vmap
+ torch.vmap = patched_vmap
+
+ # torch.jit.isinstance
+ f_jit_isinstance = torch.jit.isinstance
+ torch.jit.isinstance = isinstance
+
+ # torch._dynamo.mark_static_address
+ f_mark_static_address = torch._dynamo.mark_static_address
+ torch._dynamo.mark_static_address = lambda *_, **y_: None
+
+ # torch._subclasses.fake_impls.infer_size
+ f_infer_size = torch._subclasses.fake_impls.infer_size
+ torch._subclasses.fake_impls.infer_size = patched_infer_size
+ if patch_details:
+ patch_details.append("torch", f_infer_size, patched_infer_size)
+
+ # torch._refs._broadcast_shapes
+ f__broadcast_shapes = torch._refs._broadcast_shapes
+ torch._refs._broadcast_shapes = patched__broadcast_shapes
+ torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
+ if patch_details:
+ patch_details.append("torch", f__broadcast_shapes, patched__broadcast_shapes)
+
+ # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+ f___constrain_user_specified_dimhint_range = (
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+ )
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+ patched__constrain_user_specified_dimhint_range
+ )
+ if patch_details:
+ patch_details.append(
+ "torch",
+ f___constrain_user_specified_dimhint_range,
+ patched__constrain_user_specified_dimhint_range,
+ )
+
+ # torch._prims._broadcast_in_dim_meta
+ f_broadcast_in_dim = torch._prims.broadcast_in_dim
+ f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+ _patched_dim_f = (
+ patched__broadcast_in_dim_meta_level_2
+ if patch_torch == 2
+ else patched__broadcast_in_dim_meta
+ )
+ torch._prims._broadcast_in_dim_meta = _patched_dim_f
+ torch._prims.broadcast_in_dim = _patched_dim_f
+ if patch_details:
+ patch_details.append("torch", f__broadcast_in_dim_meta, _patched_dim_f)
+
+ # torch._refs._maybe_broadcast
+ f__maybe_broadcast = torch._refs._maybe_broadcast
+ torch._refs._maybe_broadcast = patched__maybe_broadcast
+ if patch_details:
+ patch_details.append("torch", f__maybe_broadcast, patched__maybe_broadcast)
+
+ # ShapeEnv
+ f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+ ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+ if patch_details:
+ patch_details.append(
+ "torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr
+ )
+
+ # torch._export.non_strict_utils.produce_guards_and_solve_constraints
+ if catch_constraints:
+ if verbose:
+ print("[torch_export_patches] modifies shape constraints")
+ f_produce_guards_and_solve_constraints = (
+ torch._export.non_strict_utils.produce_guards_and_solve_constraints
+ )
+ f__check_input_constraints_for_graph = (
+ torch._export.utils._check_input_constraints_for_graph
+ )
+ torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+ lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints(
+ f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs
+ )
+ )
+ torch._export.utils._check_input_constraints_for_graph = (
+ lambda *args, **kwargs: patch__check_input_constraints_for_graph(
+ f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs
+ )
+ )
+
+ if patch_torch and stop_if_static:
+ ShapeEnv._log_guard_remember = ShapeEnv._log_guard
+
+ if verbose:
+ print("[torch_export_patches] assert when a dynamic dimension turns static")
+ print("[torch_export_patches] replaces ShapeEnv._set_replacement")
+
+ f_shape_env__set_replacement = ShapeEnv._set_replacement
+ ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
+ if patch_details:
+ patch_details.append(
+ "torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement
+ )
+
+ if verbose:
+ print("[torch_export_patches] replaces ShapeEnv._log_guard")
+ f_shape_env__log_guard = ShapeEnv._log_guard
+ ShapeEnv._log_guard = patched_ShapeEnv._log_guard
+ if patch_details:
+ patch_details.append("torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard)
+
+ if stop_if_static > 1:
+ if verbose:
+ print("[torch_export_patches] replaces ShapeEnv._check_frozen")
+ f_shape_env__check_frozen = ShapeEnv._check_frozen
+ ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
+ if patch_details:
+ patch_details.append(
+ "torch", f_shape_env__check_frozen, ShapeEnv._check_frozen
+ )
+ return (
+ f___constrain_user_specified_dimhint_range,
+ f__broadcast_in_dim_meta,
+ f__broadcast_shapes,
+ f__check_input_constraints_for_graph,
+ f__maybe_broadcast,
+ f_broadcast_in_dim,
+ f_infer_size,
+ f_jit_isinstance,
+ f_mark_static_address,
+ f_produce_guards_and_solve_constraints,
+ f_shape_env__check_frozen,
+ f_shape_env__evaluate_expr,
+ f_shape_env__log_guard,
+ f_shape_env__set_replacement,
+ f_vmap,
+ )
+
+
+ def _unpatch_torch(
+ verbose: int,
+ _patch_details: PatchDetails,
+ patch_torch: int,
+ catch_constraints: bool,
+ stop_if_static: int,
+ f___constrain_user_specified_dimhint_range: Optional[Callable],
+ f__broadcast_in_dim_meta: Optional[Callable],
+ f__broadcast_shapes: Optional[Callable],
+ f__check_input_constraints_for_graph: Optional[Callable],
+ f__maybe_broadcast: Optional[Callable],
+ f_broadcast_in_dim: Optional[Callable],
+ f_infer_size: Optional[Callable],
+ f_jit_isinstance: Optional[Callable],
+ f_mark_static_address: Optional[Callable],
+ f_produce_guards_and_solve_constraints: Optional[Callable],
+ f_shape_env__check_frozen: Optional[Callable],
+ f_shape_env__evaluate_expr: Optional[Callable],
+ f_shape_env__log_guard: Optional[Callable],
+ f_shape_env__set_replacement: Optional[Callable],
+ f_vmap: Optional[Callable],
+ ):
+ import torch
+ import torch.jit
+ import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+ # this should disappear when torch.jit is removed
+ torch.vmap = f_vmap
+ torch.jit.isinstance = f_jit_isinstance
+ torch._dynamo.mark_static_address = f_mark_static_address
+ # tracked by https://github.com/pytorch/pytorch/issues/143495
+ torch._subclasses.fake_impls.infer_size = f_infer_size
+ torch._refs._broadcast_shapes = f__broadcast_shapes
+ torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
+ torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+ f___constrain_user_specified_dimhint_range
+ )
+ torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+ torch._prims.broadcast_in_dim = f_broadcast_in_dim
+ torch._refs._maybe_broadcast = f__maybe_broadcast
+ ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
+
+ if verbose:
+ print("[torch_export_patches] restored pytorch functions")
+
+ if patch_torch and stop_if_static:
+ if verbose:
+ print("[torch_export_patches] restored ShapeEnv._set_replacement")
+
+ ShapeEnv._set_replacement = f_shape_env__set_replacement
+
+ if verbose:
+ print("[torch_export_patches] restored ShapeEnv._log_guard")
+
+ ShapeEnv._log_guard = f_shape_env__log_guard
+
+ if stop_if_static > 1:
+ if verbose:
+ print("[torch_export_patches] restored ShapeEnv._check_frozen")
+ ShapeEnv._check_frozen = f_shape_env__check_frozen
+
+ if patch_torch and catch_constraints:
+ # to catch or skip dynamic_shapes issues
+ torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+ f_produce_guards_and_solve_constraints
+ )
+ torch._export.utils._check_input_constraints_for_graph = (
+ f__check_input_constraints_for_graph
+ )
+ if verbose:
+ print("[torch_export_patches] restored shape constraints")
+
+
+ def _patch_transformers(
+ verbose: int, patch_details: PatchDetails
+ ) -> Tuple[Optional[Callable], ...]:
+ import transformers
+
+ try:
+ import transformers.masking_utils as masking_utils
+ except ImportError:
+ masking_utils = None
+
+ try:
+ import transformers.integrations.sdpa_attention as sdpa_attention
+ except ImportError:
+ sdpa_attention = None
+
+ try:
+ import transformers.modeling_utils as modeling_utils
+ except ImportError:
+ modeling_utils = None
+
+ try:
+ import transformers.modeling_rope_utils as modeling_rope_utils
+ except ImportError:
+ modeling_rope_utils = None
+
+ if (
+ patch_details
+ and modeling_rope_utils
+ and hasattr(modeling_rope_utils, "dynamic_rope_update")
+ ):
+ patch_details.append(
+ "patch_transformers",
+ modeling_rope_utils.dynamic_rope_update,
+ patch_transformers_list.patched_dynamic_rope_update,
+ )
+
+ if verbose:
+ print(f"[torch_export_patches] transformers.__version__={transformers.__version__!r}")
+ assert not sdpa_attention.sdpa_attention_forward.__name__.startswith("patched_"), (
+ f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+ f"sdpa_attention.sdpa_attention_forward={sdpa_attention.sdpa_attention_forward}"
+ )
+
+ f_transformers__vmap_for_bhqkv = None
+ f_transformers_eager_mask = None
+ f_transformers_sdpa_attention_forward = None
+ f_transformers_sdpa_mask = None
+ f_transformers_sdpa_mask_recent_torch = None
+
+ if (  # vmap
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "_vmap_for_bhqkv")
+ ):
+ if verbose:
+ print("[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv")
+ f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
+ masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
+ if patch_details:
+ patch_details.append(
+ "transformers",
+ f_transformers__vmap_for_bhqkv,
+ patch_transformers_list.patched__vmap_for_bhqkv,
+ )
+
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.sdpa_mask_recent_torch"
+ )
+ f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+ masking_utils.sdpa_mask_recent_torch = (
+ patch_transformers_list.patched_sdpa_mask_recent_torch
+ )
+ if patch_details:
+ patch_details.append(
+ "transformers",
+ f_transformers_sdpa_mask_recent_torch,
+ patch_transformers_list.patched_sdpa_mask_recent_torch,
+ )
+ if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+ if verbose:
+ print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask")
+ f_transformers_sdpa_mask = masking_utils.sdpa_mask
+ masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch
+ if patch_details:
+ patch_details.append(
+ "transformers",
+ f_transformers_sdpa_mask,
+ patch_transformers_list.patched_sdpa_mask_recent_torch,
+ )
+ else:
+ f_transformers_sdpa_mask = None
+
+ if (  # eager_mask
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "eager_mask")
+ ):
+ if verbose:
+ print("[torch_export_patches] patches transformers.masking_utils.eager_mask")
+ f_transformers_eager_mask = masking_utils.eager_mask
+ masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+ if patch_details:
+ patch_details.append(
+ "transformers",
+ f_transformers_eager_mask,
+ patch_transformers_list.patched_eager_mask,
+ )
+ if (
+ "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+ and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+ == f_transformers_eager_mask
+ ):
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.eager_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+ patch_transformers_list.patched_eager_mask
+ )
+
+ if (  # sdpa_mask
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "sdpa_mask")
+ and f_transformers_sdpa_mask is not None
+ ):
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.masking_utils.sdpa_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
+ if (
+ "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+ and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] == f_transformers_sdpa_mask
+ ):
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+ patch_transformers_list.patched_sdpa_mask_recent_torch
+ )
+
+ if (  # sdpa_attention_forward
+ sdpa_attention is not None
+ and modeling_utils is not None
+ and hasattr(sdpa_attention, "sdpa_attention_forward")
+ and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+ and hasattr(modeling_utils, "AttentionInterface")
+ ):
+ if verbose:
+ print(
+ "[torch_export_patches] patches "
+ "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+ )
+ f_transformers_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+ assert not f_transformers_sdpa_attention_forward.__name__.startswith("patched_"), (
+ f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+ f"sdpa_attention.sdpa_attention_forward={f_transformers_sdpa_attention_forward}"
+ )
+ sdpa_attention.sdpa_attention_forward = (
+ patch_transformers_list.patched_sdpa_attention_forward
+ )
+ modeling_utils.sdpa_attention_forward = (
+ patch_transformers_list.patched_sdpa_attention_forward
+ )
+ modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+ patch_transformers_list.patched_sdpa_attention_forward
+ )
+ if patch_details:
+ patch_details.append(
+ "transformers",
+ f_transformers_sdpa_attention_forward,
+ patch_transformers_list.patched_sdpa_attention_forward,
+ )
+
+ revert_patches_info = patch_module_or_classes(
+ patch_transformers_list, verbose=verbose, patch_details=patch_details
+ )
+
+ return (
+ f_transformers__vmap_for_bhqkv,
+ f_transformers_eager_mask,
+ f_transformers_sdpa_attention_forward,
+ f_transformers_sdpa_mask,
+ f_transformers_sdpa_mask_recent_torch,
+ revert_patches_info,
+ )
+
+
+ def _unpatch_transformers(
+ verbose: int,
+ _patch_details: PatchDetails,
+ f_transformers__vmap_for_bhqkv: Optional[Callable],
+ f_transformers_eager_mask: Optional[Callable],
+ f_transformers_sdpa_attention_forward: Optional[Callable],
+ f_transformers_sdpa_mask: Optional[Callable],
+ f_transformers_sdpa_mask_recent_torch: Optional[Callable],
+ revert_patches_info: Optional[Callable],
+ ):
+
+ try:
+ import transformers.masking_utils as masking_utils
+ except ImportError:
+ masking_utils = None
+
+ try:
+ import transformers.integrations.sdpa_attention as sdpa_attention
+ except ImportError:
+ sdpa_attention = None
+
+ try:
+ import transformers.modeling_utils as modeling_utils
+ except ImportError:
+ modeling_utils = None
+
+ try:
+ import transformers.masking_utils as masking_utils
+ except ImportError:
+ masking_utils = None
+ if verbose:
+ print("[torch_export_patches] unpatches transformers")
+
+ if (  # vmap
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "_vmap_for_bhqkv")
+ ):
+ assert f_transformers__vmap_for_bhqkv.__name__ == "_vmap_for_bhqkv", (
+ f"corrupted function '_vmap_for_bhqkv', its name is "
+ f"{f_transformers__vmap_for_bhqkv.__name__!r}"
+ )
+ masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
+ if verbose:
+ print("[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv")
+
+ assert f_transformers_sdpa_mask_recent_torch.__name__ == "sdpa_mask_recent_torch", (
+ f"corrupted function 'sdpa_mask_recent_torch', its name is "
+ f"{f_transformers_sdpa_mask_recent_torch.__name__!r}"
+ )
+ masking_utils.sdpa_mask_recent_torch = f_transformers_sdpa_mask_recent_torch
+
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.sdpa_mask_recent_torch"
+ )
+
+ if f_transformers_sdpa_mask is not None:
+ assert f_transformers_sdpa_mask.__name__ in (
+ "sdpa_mask",
+ "sdpa_mask_recent_torch",
+ ), (
+ f"corrupted function 'sdpa_mask', its name is "
+ f"{f_transformers_sdpa_mask.__name__!r}"
+ )
+ masking_utils.sdpa_mask = f_transformers_sdpa_mask
+ if verbose:
+ print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")
+
+ if (  # eager_mask
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "eager_mask")
+ ):
+ assert f_transformers_eager_mask.__name__ == "eager_mask", (
+ f"corrupted function 'eager_mask', its name is "
+ f"{f_transformers_eager_mask.__name__!r}"
+ )
+ masking_utils.eager_mask = f_transformers_eager_mask
+ if verbose:
+ print("[torch_export_patches] restored transformers.masking_utils.eager_mask")
+ assert masking_utils.eager_mask.__name__ == "eager_mask", (
+ f"corrupted function 'eager_mask', its name is "
+ f"{masking_utils.eager_mask.__name__!r}"
+ )
+ if (
+ "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+ and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+ == patch_transformers_list.patched_eager_mask
+ ):
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = f_transformers_eager_mask
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.eager_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
+ assert masking_utils.eager_mask.__name__ == "eager_mask", (
+ f"corrupted function 'eager_mask', its name is "
+ f"{masking_utils.eager_mask.__name__!r}"
+ )
+
+ if (  # sdpa_mask
+ masking_utils
+ and patch_transformers_list.patch_masking_utils
+ and hasattr(masking_utils, "sdpa_mask")
+ ):
+ if (
+ "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+ and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+ == patch_transformers_list.patched_sdpa_mask_recent_torch
+ ):
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = f_transformers_sdpa_mask
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.masking_utils.sdpa_mask "
+ "in ALL_MASK_ATTENTION_FUNCTIONS"
+ )
+
+ if (  # sdpa_attention_forward
+ sdpa_attention is not None
+ and modeling_utils is not None
+ and hasattr(sdpa_attention, "sdpa_attention_forward")
+ and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+ and hasattr(modeling_utils, "AttentionInterface")
+ ):
+ sdpa_attention.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+ modeling_utils.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+ modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+ f_transformers_sdpa_attention_forward
+ )
+ if verbose:
+ print(
+ "[torch_export_patches] restored "
+ "transformers.integrations.sdpa_attention."
+ "sdpa_attention_forward"
+ )
+
+ unpatch_module_or_classes(patch_transformers_list, revert_patches_info, verbose=verbose)
+
+
  @contextlib.contextmanager
  def torch_export_patches(
  patch_sympy: bool = True,
@@ -170,6 +817,7 @@ def torch_export_patches(
  custom_patches: Optional[List[type["torch.nn.Module"]]] = None,  # noqa: F821
  rewrite: Optional[List[Callable]] = None,
  dump_rewriting: Optional[str] = None,
+ patch_details: Optional[PatchDetails] = None,
  ) -> Callable:
  """
  Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -200,6 +848,7 @@ def torch_export_patches(
  <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
  its documentation provides possible values
  :param dump_rewriting: dumps rewriting information in file beginning with that prefix
+ :param patch_details: if specified, this class is used to stored every rewritten done.
  :param verbose: to show which patches is applied
 
  The list of available patches.
@@ -270,7 +919,10 @@
 
  with (
  torch_export_rewrite(
- rewrite=rewrite, dump_rewriting=dump_rewriting, verbose=verbose
+ rewrite=rewrite,
+ dump_rewriting=dump_rewriting,
+ verbose=verbose,
+ patch_details=patch_details,
  ),
  torch_export_patches(  # type: ignore[var-annotated]
  patch_sympy=patch_sympy,
@@ -282,6 +934,7 @@
  verbose=verbose,
  patch=patch,
  custom_patches=custom_patches,
+ patch_details=patch_details,
  ) as f,
  ):
  try:
@@ -300,19 +953,13 @@
  finally:
  unregister_cache_serialization(done, verbose=verbose)
  else:
- import torch
- import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
- import torch.jit
-
  if verbose:
  print(
  "[torch_export_patches] replace torch.jit.isinstance, "
  "torch._dynamo.mark_static_address"
  )
 
- ########
  # caches
- ########
 
  cache_done = register_cache_serialization(
  patch_transformers=patch_transformers,
@@ -320,282 +967,50 @@
  verbose=verbose,
  )
 
- #############
- # patch sympy
- #############
+ # patches
 
  if patch_sympy:
- import sympy
-
- f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
-
- if verbose:
- print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
- print("[torch_export_patches] patch sympy")
-
- sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
-
- ###############
- # patch pytorch
- ###############
+ (f_sympy_name,) = _patch_sympy(verbose, patch_details)
 
  if patch_torch:
- from torch.fx.experimental.symbolic_shapes import ShapeEnv
- from .patches.patch_torch import (
- patched_infer_size,
- patched_vmap,
- patched__broadcast_shapes,
- patched__constrain_user_specified_dimhint_range,
- _catch_produce_guards_and_solve_constraints,
- patch__check_input_constraints_for_graph,
- patched__broadcast_in_dim_meta,
- patched__broadcast_in_dim_meta_level_2,
- patched__maybe_broadcast,
- patched_ShapeEnv,
- )
-
- if verbose:
- print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
- print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
- print("[torch_export_patches] patch pytorch")
-
- # torch.vmap
- f_vmap = torch.vmap
- torch.vmap = patched_vmap
-
- # torch.jit.isinstance
- f_jit_isinstance = torch.jit.isinstance
- torch.jit.isinstance = isinstance
-
- # torch._dynamo.mark_static_address
- f_mark_static_address = torch._dynamo.mark_static_address
- torch._dynamo.mark_static_address = lambda *_, **y_: None
-
- # torch._subclasses.fake_impls.infer_size
- f_infer_size = torch._subclasses.fake_impls.infer_size
- torch._subclasses.fake_impls.infer_size = patched_infer_size
-
- # torch._refs._broadcast_shapes
- f__broadcast_shapes = torch._refs._broadcast_shapes
- torch._refs._broadcast_shapes = patched__broadcast_shapes
- torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
-
- # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
- f___constrain_user_specified_dimhint_range = (
- torch._export.non_strict_utils._constrain_user_specified_dimhint_range
- )
- torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
- patched__constrain_user_specified_dimhint_range
- )
-
- # torch._prims._broadcast_in_dim_meta
- f_broadcast_in_dim = torch._prims.broadcast_in_dim
- f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
- _patched_dim_f = (
- patched__broadcast_in_dim_meta_level_2
- if patch_torch == 2
- else patched__broadcast_in_dim_meta
+ (
+ f___constrain_user_specified_dimhint_range,
+ f__broadcast_in_dim_meta,
+ f__broadcast_shapes,
+ f__check_input_constraints_for_graph,
+ f__maybe_broadcast,
+ f_broadcast_in_dim,
+ f_infer_size,
+ f_jit_isinstance,
+ f_mark_static_address,
+ f_produce_guards_and_solve_constraints,
+ f_shape_env__check_frozen,
+ f_shape_env__evaluate_expr,
+ f_shape_env__log_guard,
+ f_shape_env__set_replacement,
+ f_vmap,
+ ) = _patch_torch(
+ verbose, patch_details, patch_torch, catch_constraints, stop_if_static
  )
- torch._prims._broadcast_in_dim_meta = _patched_dim_f
- torch._prims.broadcast_in_dim = _patched_dim_f
-
- # torch._refs._maybe_broadcast
- f__maybe_broadcast = torch._refs._maybe_broadcast
- torch._refs._maybe_broadcast = patched__maybe_broadcast
-
- # ShapeEnv
- f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
- ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
-
- # torch._export.non_strict_utils.produce_guards_and_solve_constraints
- if patch_torch and catch_constraints:
- if verbose:
- print("[torch_export_patches] modifies shape constraints")
- f_produce_guards_and_solve_constraints = (
- torch._export.non_strict_utils.produce_guards_and_solve_constraints
- )
- f__check_input_constraints_for_graph = (
- torch._export.utils._check_input_constraints_for_graph
- )
- torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
- lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints(
- f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs
- )
- )
- torch._export.utils._check_input_constraints_for_graph = (
- lambda *args, **kwargs: patch__check_input_constraints_for_graph(
- f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs
- )
- )
-
- if patch_torch and stop_if_static:
- ShapeEnv._log_guard_remember = ShapeEnv._log_guard
-
- if verbose:
- print("[torch_export_patches] assert when a dynamic dimension turns static")
- print("[torch_export_patches] replaces ShapeEnv._set_replacement")
-
- f_shape_env__set_replacement = ShapeEnv._set_replacement
- ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
-
- if verbose:
- print("[torch_export_patches] replaces ShapeEnv._log_guard")
- f_shape_env__log_guard = ShapeEnv._log_guard
- ShapeEnv._log_guard = patched_ShapeEnv._log_guard
-
- if stop_if_static > 1:
- if verbose:
- print("[torch_export_patches] replaces ShapeEnv._check_frozen")
- f_shape_env__check_frozen = ShapeEnv._check_frozen
- ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
-
- ####################
- # patch transformers
- ####################
 
  if patch_transformers:
- try:
- import transformers.masking_utils as masking_utils
- except ImportError:
- masking_utils = None
-
- try:
- import transformers.integrations.sdpa_attention as sdpa_attention
- except ImportError:
- sdpa_attention = None
-
- try:
- import transformers.modeling_utils as modeling_utils
- except ImportError:
- modeling_utils = None
-
- if verbose:
- import transformers
-
- print(
- f"[torch_export_patches] transformers.__version__="
- f"{transformers.__version__!r}"
- )
- revert_patches_info = patch_module_or_classes(
- patch_transformers_list, verbose=verbose
- )
-
- if ( # vmap
- masking_utils
- and patch_transformers_list.patch_masking_utils
- and hasattr(masking_utils, "_vmap_for_bhqkv")
- ):
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.masking_utils._vmap_for_bhqkv"
- )
- f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
- masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
-
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.masking_utils.sdpa_mask_recent_torch"
- )
- f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
- masking_utils.sdpa_mask_recent_torch = (
- patch_transformers_list.patched_sdpa_mask_recent_torch
- )
- if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.masking_utils.sdpa_mask"
- )
- f_transformers_sdpa_mask = masking_utils.sdpa_mask
- masking_utils.sdpa_mask = (
- patch_transformers_list.patched_sdpa_mask_recent_torch
- )
- else:
- f_transformers_sdpa_mask = None
-
- if ( # eager_mask
- masking_utils
- and patch_transformers_list.patch_masking_utils
- and hasattr(masking_utils, "eager_mask")
- ):
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.masking_utils.eager_mask"
- )
- f_transformers_eager_mask = masking_utils.eager_mask
- masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
- if (
- "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
- and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
- == f_transformers_eager_mask
- ):
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.masking_utils.eager_mask "
- "in ALL_MASK_ATTENTION_FUNCTIONS"
- )
- masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
- patch_transformers_list.patched_eager_mask
- )
-
- if ( # sdpa_mask
- masking_utils
- and patch_transformers_list.patch_masking_utils
- and hasattr(masking_utils, "sdpa_mask")
- and f_transformers_sdpa_mask is not None
- ):
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.masking_utils.sdpa_mask "
- "in ALL_MASK_ATTENTION_FUNCTIONS"
- )
- if (
- "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
- and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
- == f_transformers_sdpa_mask
- ):
- masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
- patch_transformers_list.patched_sdpa_mask_recent_torch
- )
-
- if ( # sdpa_attention_forward
- sdpa_attention is not None
- and modeling_utils is not None
- and hasattr(sdpa_attention, "sdpa_attention_forward")
- and hasattr(sdpa_attention, "use_gqa_in_sdpa")
- and hasattr(modeling_utils, "AttentionInterface")
- ):
- if verbose:
- print(
- "[torch_export_patches] patches "
- "transformers.integrations.sdpa_attention.sdpa_attention_forward"
- )
- f_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
- sdpa_attention.sdpa_attention_forward = (
- patch_transformers_list.patched_sdpa_attention_forward
- )
- modeling_utils.sdpa_attention_forward = (
- patch_transformers_list.patched_sdpa_attention_forward
- )
- modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
- patch_transformers_list.patched_sdpa_attention_forward
- )
+ (
+ f_transformers__vmap_for_bhqkv,
+ f_transformers_eager_mask,
+ f_transformers_sdpa_attention_forward,
+ f_transformers_sdpa_mask,
+ f_transformers_sdpa_mask_recent_torch,
+ revert_patches_info,
+ ) = _patch_transformers(verbose, patch_details)
 
  if custom_patches:
  if verbose:
  print("[torch_export_patches] applies custom patches")
  revert_custom_patches_info = patch_module_or_classes(
- custom_patches, verbose=verbose
+ custom_patches, verbose=verbose, patch_details=patch_details
  )
 
- ########
  # export
- ########
 
  fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
 
@@ -605,73 +1020,50 @@
  try:
  yield fct_callable
  finally:
- #######
- # sympy
- #######
+
+ # unpatch
 
  if verbose:
  print("[torch_export_patches] remove patches")
 
  if patch_sympy:
- # tracked by https://github.com/pytorch/pytorch/issues/143494
- if f_sympy_name:
- sympy.core.numbers.IntegerConstant.name = f_sympy_name
- else:
- delattr(sympy.core.numbers.IntegerConstant, "name")
-
- if verbose:
- print("[torch_export_patches] restored sympy functions")
-
- #######
- # torch
- #######
+ _unpatch_sympy(verbose, f_sympy_name)
 
  if patch_torch:
- # this should disappear when torch.jit is removed
- torch.vmap = f_vmap
- torch.jit.isinstance = f_jit_isinstance
- torch._dynamo.mark_static_address = f_mark_static_address
- # tracked by https://github.com/pytorch/pytorch/issues/143495
- torch._subclasses.fake_impls.infer_size = f_infer_size
- torch._refs._broadcast_shapes = f__broadcast_shapes
- torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
- torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
- f___constrain_user_specified_dimhint_range
+ _unpatch_torch(
+ verbose,
+ patch_details,
+ patch_torch,
+ catch_constraints,
+ stop_if_static,
+ f___constrain_user_specified_dimhint_range,
+ f__broadcast_in_dim_meta,
+ f__broadcast_shapes,
+ f__check_input_constraints_for_graph,
+ f__maybe_broadcast,
+ f_broadcast_in_dim,
+ f_infer_size,
+ f_jit_isinstance,
+ f_mark_static_address,
+ f_produce_guards_and_solve_constraints,
+ f_shape_env__check_frozen,
+ f_shape_env__evaluate_expr,
+ f_shape_env__log_guard,
+ f_shape_env__set_replacement,
+ f_vmap,
  )
- torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
- torch._prims.broadcast_in_dim = f_broadcast_in_dim
- torch._refs._maybe_broadcast = f__maybe_broadcast
- ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
-
- if verbose:
- print("[torch_export_patches] restored pytorch functions")
-
- if patch_torch and stop_if_static:
- if verbose:
- print("[torch_export_patches] restored ShapeEnv._set_replacement")
-
- ShapeEnv._set_replacement = f_shape_env__set_replacement
-
- if verbose:
- print("[torch_export_patches] restored ShapeEnv._log_guard")
 
- ShapeEnv._log_guard = f_shape_env__log_guard
-
- if stop_if_static > 1:
- if verbose:
- print("[torch_export_patches] restored ShapeEnv._check_frozen")
- ShapeEnv._check_frozen = f_shape_env__check_frozen
-
- if patch_torch and catch_constraints:
- # to catch or skip dynamic_shapes issues
- torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
- f_produce_guards_and_solve_constraints
- )
- torch._export.utils._check_input_constraints_for_graph = (
- f__check_input_constraints_for_graph
+ if patch_transformers:
+ _unpatch_transformers(
+ verbose,
+ patch_details,
+ f_transformers__vmap_for_bhqkv,
+ f_transformers_eager_mask,
+ f_transformers_sdpa_attention_forward,
+ f_transformers_sdpa_mask,
+ f_transformers_sdpa_mask_recent_torch,
+ revert_patches_info,
  )
- if verbose:
- print("[torch_export_patches] restored shape constraints")
 
  if custom_patches:
  if verbose:
@@ -680,118 +1072,6 @@
  custom_patches, revert_custom_patches_info, verbose=verbose
  )
 
- ##############
- # transformers
- ##############
-
- if patch_transformers:
- try:
- import transformers.masking_utils as masking_utils
- except ImportError:
- masking_utils = None
- if verbose:
- print("[torch_export_patches] unpatches transformers")
- unpatch_module_or_classes(
- patch_transformers_list, revert_patches_info, verbose=verbose
- )
-
- if ( # vmap
- masking_utils
- and patch_transformers_list.patch_masking_utils
- and hasattr(masking_utils, "_vmap_for_bhqkv")
- ):
- masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
-
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils._vmap_for_bhqkv"
- )
-
- masking_utils.sdpa_mask_recent_torch = (
- f_transformers_sdpa_mask_recent_torch
- )
-
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils.sdpa_mask_recent_torch"
- )
-
- if f_transformers_sdpa_mask is not None:
- masking_utils.sdpa_mask = f_transformers_sdpa_mask
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils.sdpa_mask"
- )
-
- if ( # eager_mask
- masking_utils
- and patch_transformers_list.patch_masking_utils
- and hasattr(masking_utils, "eager_mask")
- ):
- f_transformers_eager_mask = masking_utils.eager_mask
- masking_utils.eager_mask = f_transformers_eager_mask
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils.eager_mask"
- )
- if (
- "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
- and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
- == patch_transformers_list.patched_eager_mask
- ):
- masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
- f_transformers_eager_mask
- )
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils.eager_mask "
- "in ALL_MASK_ATTENTION_FUNCTIONS"
- )
-
- if ( # sdpa_mask
- masking_utils
- and patch_transformers_list.patch_masking_utils
- and hasattr(masking_utils, "sdpa_mask")
- ):
- if (
- "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
- and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
- == patch_transformers_list.patched_sdpa_mask_recent_torch
- ):
- masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
- f_transformers_sdpa_mask
- )
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.masking_utils.sdpa_mask "
- "in ALL_MASK_ATTENTION_FUNCTIONS"
- )
-
- if ( # sdpa_attention_forward
- sdpa_attention is not None
- and modeling_utils is not None
- and hasattr(sdpa_attention, "sdpa_attention_forward")
- and hasattr(sdpa_attention, "use_gqa_in_sdpa")
- and hasattr(modeling_utils, "AttentionInterface")
- ):
- sdpa_attention.sdpa_attention_forward = f_sdpa_attention_forward
- modeling_utils.sdpa_attention_forward = f_sdpa_attention_forward
- modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
- f_sdpa_attention_forward
- )
- if verbose:
- print(
- "[torch_export_patches] restored "
- "transformers.integrations.sdpa_attention."
- "sdpa_attention_forward"
- )
-
  ########
  # caches
  ########