onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -0,0 +1,1098 @@
+ import functools
+ import importlib
+ import inspect
+ import contextlib
+ import re
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from .onnx_export_serialization import (
+     register_cache_serialization,
+     unregister_cache_serialization,
+ )
+ from .patches import patch_transformers as patch_transformers_list
+ from .patch_details import PatchDetails
+
+
+ def get_function(name: str) -> Tuple[type, Callable]:
+     """Returns the module and the function based on its name."""
+     spl = name.split(".")
+     module_name = ".".join(spl[:-1])
+     fname = spl[-1]
+     mod = importlib.import_module(module_name)
+     if not hasattr(mod, fname):
+         return None, None
+     return mod, getattr(mod, fname)
+
+
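For illustration, `get_function` resolves a dotted name to its defining module and attribute, returning `(None, None)` when the attribute is missing; a minimal sketch (the target `os.path.join` is an arbitrary example, not taken from the package):

    import os

    mod, fn = get_function("os.path.join")
    assert mod is os.path and fn is os.path.join
    mod, fn = get_function("os.path.does_not_exist")
    assert mod is None and fn is None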
+ @functools.lru_cache
+ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+     """Returns the list of patches to apply for a specific module."""
+     to_patch = []
+     for k in dir(mod):
+         if k.startswith("patched_"):
+             v = getattr(mod, k)
+             if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                 to_patch.append(v)
+             else:
+                 # a function
+                 doc = v.__doc__.lstrip()
+                 if doc.startswith("manual patch"):
+                     continue
+                 reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
+                 fall = reg.findall(doc)
+                 assert (
+                     len(fall) == 1
+                 ), f"Unable to find patching information for {v} in \n{doc}"
+                 fmod, f = get_function(fall[0])
+                 if fmod is None and f is None:
+                     # The function does not exist in this version of transformers.
+                     # No patch is needed.
+                     continue
+                 to_patch.append({"module": fmod, "function": f, "patch": v})
+
+     name = mod.__name__
+     return name, to_patch
+
+
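`get_patches` discovers two kinds of patches in a module: classes carrying `_PATCHED_CLASS_`/`_PATCHES_`, and standalone functions whose docstring names the target through a `[patch:...]` tag matched by the regular expression above. A hedged sketch of the function form (the target name is illustrative, not taken from the package):

    def patched_repeat_kv(hidden_states, n_rep):
        """[patch:transformers.models.llama.modeling_llama.repeat_kv]

        Export-friendly replacement of the original function.
        """
        ...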
+ def patch_module_or_classes(
+     mod, verbose: int = 0, patch_details: Optional[PatchDetails] = None
+ ) -> Dict[type, Dict[type, Callable]]:
+     """
+     Applies all patches defined in classes prefixed by ``patched_``.
+     ``cls._PATCHED_CLASS_`` defines the class to patch,
+     ``cls._PATCHES_`` defines the methods to patch.
+     The returned information needs to be sent to :func:`unpatch_module_or_classes`
+     to revert the changes.
+
+     :param mod: module or list of classes to patch
+     :param verbose: verbosity
+     :param patch_details: used to store information about the applied patches
+     :return: patch info
+     """
+     if isinstance(mod, list):
+         to_patch = mod
+         name = "list"
+         list_name = "auto/list"
+     else:
+         name, to_patch = get_patches(mod, verbose)
+         list_name = f"auto/{mod.__name__.split('.')[-1]}"
+
+     res = {}
+     for cls in to_patch:
+         if isinstance(cls, dict):
+             # a function
+             keep = {}
+             original = cls["module"]
+             f = cls["function"]
+             assert not f.__name__.startswith("patched_"), (
+                 f"The function {f} was already patched or the patch was not removed, "
+                 f"original={original}"
+             )
+             res[f] = f
+             if verbose:
+                 print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+             if patch_details:
+                 patch_details.append(list_name, getattr(original, f.__name__), cls["patch"])
+             setattr(original, f.__name__, cls["patch"])
+             continue
+
+         original = cls._PATCHED_CLASS_
+         methods = [_ for _ in cls._PATCHES_ if _ is not None]
+         if verbose:
+             print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
+
+         keep = {n: getattr(original, n, None) for n in methods}
+         for n in methods:
+             if patch_details:
+                 if hasattr(original, n):
+                     p = patch_details.append(list_name, getattr(original, n), getattr(cls, n))
+                 else:
+                     p = patch_details.append(
+                         list_name, f"{original.__name__}{n}", getattr(cls, n)
+                     )
+                 if "@patched_dynamic_rope_update" in inspect.getsource(getattr(cls, n)):
+                     # a tweak to include that patch.
+                     f = patch_details.find("patched_dynamic_rope_update")
+                     if f is not None:
+                         p.add_dependency(f)
+             setattr(original, n, getattr(cls, n))
+         res[cls] = keep
+
+     return res
+
+
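For illustration, a class-based patch consumed by `patch_module_or_classes` might look like the following sketch (`MyModel`, its `forward`, and the surrounding module are hypothetical):

    class patched_MyModel:
        # class whose methods are replaced while exporting
        _PATCHED_CLASS_ = MyModel
        # names of the methods copied over MyModel
        _PATCHES_ = ["forward"]

        def forward(self, x):
            ...  # control-flow-free reimplementation

    info = patch_module_or_classes([patched_MyModel], verbose=1)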
+ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
+     """
+     Reverts modifications made by :func:`patch_module_or_classes`.
+
+     :param mod: module or list of classes to unpatch
+     :param info: patch info returned by :func:`patch_module_or_classes`
+     :param verbose: verbosity
+     """
+     if isinstance(mod, list):
+         to_patch = mod
+         name = "list"
+     else:
+         name, to_patch = get_patches(mod, verbose)
+
+     set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+     dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}
+
+     for cls, methods in info.items():
+         if cls in set_patch_cls:
+             if verbose:
+                 print(
+                     f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                 )
+             original = cls._PATCHED_CLASS_
+             for n, v in methods.items():
+                 if v is None:
+                     # The method did not exist. We remove it.
+                     delattr(original, n)
+                 else:
+                     setattr(original, n, v)
+             continue
+         assert cls in dict_patch_fct, (
+             f"No patch registered for {cls} in {mod} "
+             f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+         )
+         patch = dict_patch_fct[cls]
+         if verbose:
+             print(
+                 f"[unpatch_module_or_classes] function "
+                 f"{patch['module'].__name__}.{cls.__name__}"
+             )
+         setattr(patch["module"], cls.__name__, patch["function"])
+
+
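The two functions are meant to be used as a pair: the dictionary returned by `patch_module_or_classes` is the only record of the original methods, so it must be passed back unchanged. A sketch with the hypothetical class above:

    info = patch_module_or_classes([patched_MyModel])
    try:
        ...  # export while the patch is active
    finally:
        unpatch_module_or_classes([patched_MyModel], info)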
+ @contextlib.contextmanager
+ def register_additional_serialization_functions(
+     patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
+ ) -> Callable:
+     """Registers the serialization functions needed to run the fx Graph."""
+     fct_callable = (
+         replacement_before_exporting
+         if patch_transformers or patch_diffusers
+         else (lambda x: x)
+     )
+     done = register_cache_serialization(
+         patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+     )
+     try:
+         yield fct_callable
+     finally:
+         unregister_cache_serialization(done, verbose=verbose)
+
+
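Usage mirrors the example given later in the docstring of `torch_export_patches`: only the cache serialization functions are registered, with no monkey-patching of torch or transformers (`model` and `inputs` are placeholders):

    with register_additional_serialization_functions(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(model, (), kwargs=inputs)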
+ def _patch_sympy(verbose: int, patch_details: PatchDetails) -> Tuple[Optional[Callable], ...]:
+     import sympy
+
+     f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
+
+     if verbose:
+         print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
+         print("[torch_export_patches] patch sympy")
+
+     sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
+     if patch_details:
+         patch_details.append(
+             "sympy",
+             f_sympy_name or "sympy.core.numbers.IntegerConstant.name",
+             sympy.core.numbers.IntegerConstant.name,
+         )
+     return (f_sympy_name,)
+
+
+ def _unpatch_sympy(verbose: int, f_sympy_name: Optional[Callable]):
+     # tracked by https://github.com/pytorch/pytorch/issues/143494
+     import sympy
+
+     if f_sympy_name:
+         sympy.core.numbers.IntegerConstant.name = f_sympy_name
+     else:
+         delattr(sympy.core.numbers.IntegerConstant, "name")
+
+     if verbose:
+         print("[torch_export_patches] restored sympy functions")
+
+
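Both helpers follow the save/patch/restore idiom used throughout this file: the original attribute is read with a `getattr` default because it may not exist, and restoration either puts it back or deletes the one that was added. A generic sketch with placeholder names:

    saved = getattr(SomeClass, "attr", None)  # may legitimately be absent
    SomeClass.attr = patched_attr
    try:
        ...  # run the export with the patch in place
    finally:
        if saved is not None:
            SomeClass.attr = saved
        else:
            delattr(SomeClass, "attr")  # the attribute did not exist before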
+ def _patch_torch(
+     verbose: int,
+     patch_details: PatchDetails,
+     patch_torch: int,
+     catch_constraints: bool,
+     stop_if_static: int,
+ ) -> Tuple[Optional[Callable], ...]:
+     import torch
+     import torch.jit
+     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+     from torch.fx.experimental.symbolic_shapes import ShapeEnv
+     from .patches.patch_torch import (
+         patched_infer_size,
+         patched_vmap,
+         patched__broadcast_shapes,
+         patched__constrain_user_specified_dimhint_range,
+         _catch_produce_guards_and_solve_constraints,
+         patch__check_input_constraints_for_graph,
+         patched__broadcast_in_dim_meta,
+         patched__broadcast_in_dim_meta_level_2,
+         patched__maybe_broadcast,
+         patched_ShapeEnv,
+     )
+
+     f___constrain_user_specified_dimhint_range = None
+     f__broadcast_in_dim_meta = None
+     f__broadcast_shapes = None
+     f__check_input_constraints_for_graph = None
+     f__maybe_broadcast = None
+     f_broadcast_in_dim = None
+     f_infer_size = None
+     f_jit_isinstance = None
+     f_mark_static_address = None
+     f_produce_guards_and_solve_constraints = None
+     f_shape_env__check_frozen = None
+     f_shape_env__evaluate_expr = None
+     f_shape_env__log_guard = None
+     f_shape_env__set_replacement = None
+     f_vmap = None
+
+     if verbose:
+         print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
+         print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
+         print("[torch_export_patches] patch pytorch")
+
+     # torch.vmap
+     f_vmap = torch.vmap
+     torch.vmap = patched_vmap
+
+     # torch.jit.isinstance
+     f_jit_isinstance = torch.jit.isinstance
+     torch.jit.isinstance = isinstance
+
+     # torch._dynamo.mark_static_address
+     f_mark_static_address = torch._dynamo.mark_static_address
+     torch._dynamo.mark_static_address = lambda *_, **y_: None
+
+     # torch._subclasses.fake_impls.infer_size
+     f_infer_size = torch._subclasses.fake_impls.infer_size
+     torch._subclasses.fake_impls.infer_size = patched_infer_size
+     if patch_details:
+         patch_details.append("torch", f_infer_size, patched_infer_size)
+
+     # torch._refs._broadcast_shapes
+     f__broadcast_shapes = torch._refs._broadcast_shapes
+     torch._refs._broadcast_shapes = patched__broadcast_shapes
+     torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
+     if patch_details:
+         patch_details.append("torch", f__broadcast_shapes, patched__broadcast_shapes)
+
+     # torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+     f___constrain_user_specified_dimhint_range = (
+         torch._export.non_strict_utils._constrain_user_specified_dimhint_range
+     )
+     torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+         patched__constrain_user_specified_dimhint_range
+     )
+     if patch_details:
+         patch_details.append(
+             "torch",
+             f___constrain_user_specified_dimhint_range,
+             patched__constrain_user_specified_dimhint_range,
+         )
+
+     # torch._prims._broadcast_in_dim_meta
+     f_broadcast_in_dim = torch._prims.broadcast_in_dim
+     f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+     _patched_dim_f = (
+         patched__broadcast_in_dim_meta_level_2
+         if patch_torch == 2
+         else patched__broadcast_in_dim_meta
+     )
+     torch._prims._broadcast_in_dim_meta = _patched_dim_f
+     torch._prims.broadcast_in_dim = _patched_dim_f
+     if patch_details:
+         patch_details.append("torch", f__broadcast_in_dim_meta, _patched_dim_f)
+
+     # torch._refs._maybe_broadcast
+     f__maybe_broadcast = torch._refs._maybe_broadcast
+     torch._refs._maybe_broadcast = patched__maybe_broadcast
+     if patch_details:
+         patch_details.append("torch", f__maybe_broadcast, patched__maybe_broadcast)
+
+     # ShapeEnv
+     f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+     ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+     if patch_details:
+         patch_details.append(
+             "torch", f_shape_env__evaluate_expr, patched_ShapeEnv._evaluate_expr
+         )
+
+     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
+     if catch_constraints:
+         if verbose:
+             print("[torch_export_patches] modifies shape constraints")
+         f_produce_guards_and_solve_constraints = (
+             torch._export.non_strict_utils.produce_guards_and_solve_constraints
+         )
+         f__check_input_constraints_for_graph = (
+             torch._export.utils._check_input_constraints_for_graph
+         )
+         torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+             lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints(
+                 f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs
+             )
+         )
+         torch._export.utils._check_input_constraints_for_graph = (
+             lambda *args, **kwargs: patch__check_input_constraints_for_graph(
+                 f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs
+             )
+         )
+
+     if patch_torch and stop_if_static:
+         ShapeEnv._log_guard_remember = ShapeEnv._log_guard
+
+         if verbose:
+             print("[torch_export_patches] assert when a dynamic dimension turns static")
+             print("[torch_export_patches] replaces ShapeEnv._set_replacement")
+
+         f_shape_env__set_replacement = ShapeEnv._set_replacement
+         ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
+         if patch_details:
+             patch_details.append(
+                 "torch", f_shape_env__set_replacement, patched_ShapeEnv._set_replacement
+             )
+
+         if verbose:
+             print("[torch_export_patches] replaces ShapeEnv._log_guard")
+         f_shape_env__log_guard = ShapeEnv._log_guard
+         ShapeEnv._log_guard = patched_ShapeEnv._log_guard
+         if patch_details:
+             patch_details.append("torch", f_shape_env__log_guard, patched_ShapeEnv._log_guard)
+
+         if stop_if_static > 1:
+             if verbose:
+                 print("[torch_export_patches] replaces ShapeEnv._check_frozen")
+             f_shape_env__check_frozen = ShapeEnv._check_frozen
+             ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
+             if patch_details:
+                 patch_details.append(
+                     "torch", f_shape_env__check_frozen, ShapeEnv._check_frozen
+                 )
+     return (
+         f___constrain_user_specified_dimhint_range,
+         f__broadcast_in_dim_meta,
+         f__broadcast_shapes,
+         f__check_input_constraints_for_graph,
+         f__maybe_broadcast,
+         f_broadcast_in_dim,
+         f_infer_size,
+         f_jit_isinstance,
+         f_mark_static_address,
+         f_produce_guards_and_solve_constraints,
+         f_shape_env__check_frozen,
+         f_shape_env__evaluate_expr,
+         f_shape_env__log_guard,
+         f_shape_env__set_replacement,
+         f_vmap,
+     )
+
+
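The `catch_constraints` branch wraps rather than replaces: the original callable is captured in a closure and forwarded, so the wrapper can soften constraint failures and still be reverted by reassigning the saved function. A sketch of the pattern with placeholder names:

    f_original = some_module.some_function
    some_module.some_function = lambda *args, **kwargs: catching_wrapper(
        f_original, *args, verbose=verbose, **kwargs
    )
    # ... later, in the unpatch step:
    some_module.some_function = f_original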
+ def _unpatch_torch(
+     verbose: int,
+     _patch_details: PatchDetails,
+     patch_torch: int,
+     catch_constraints: bool,
+     stop_if_static: int,
+     f___constrain_user_specified_dimhint_range: Optional[Callable],
+     f__broadcast_in_dim_meta: Optional[Callable],
+     f__broadcast_shapes: Optional[Callable],
+     f__check_input_constraints_for_graph: Optional[Callable],
+     f__maybe_broadcast: Optional[Callable],
+     f_broadcast_in_dim: Optional[Callable],
+     f_infer_size: Optional[Callable],
+     f_jit_isinstance: Optional[Callable],
+     f_mark_static_address: Optional[Callable],
+     f_produce_guards_and_solve_constraints: Optional[Callable],
+     f_shape_env__check_frozen: Optional[Callable],
+     f_shape_env__evaluate_expr: Optional[Callable],
+     f_shape_env__log_guard: Optional[Callable],
+     f_shape_env__set_replacement: Optional[Callable],
+     f_vmap: Optional[Callable],
+ ):
+     import torch
+     import torch.jit
+     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
+     from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+     # this should disappear when torch.jit is removed
+     torch.vmap = f_vmap
+     torch.jit.isinstance = f_jit_isinstance
+     torch._dynamo.mark_static_address = f_mark_static_address
+     # tracked by https://github.com/pytorch/pytorch/issues/143495
+     torch._subclasses.fake_impls.infer_size = f_infer_size
+     torch._refs._broadcast_shapes = f__broadcast_shapes
+     torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
+     torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
+         f___constrain_user_specified_dimhint_range
+     )
+     torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+     torch._prims.broadcast_in_dim = f_broadcast_in_dim
+     torch._refs._maybe_broadcast = f__maybe_broadcast
+     ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
+
+     if verbose:
+         print("[torch_export_patches] restored pytorch functions")
+
+     if patch_torch and stop_if_static:
+         if verbose:
+             print("[torch_export_patches] restored ShapeEnv._set_replacement")
+
+         ShapeEnv._set_replacement = f_shape_env__set_replacement
+
+         if verbose:
+             print("[torch_export_patches] restored ShapeEnv._log_guard")
+
+         ShapeEnv._log_guard = f_shape_env__log_guard
+
+         if stop_if_static > 1:
+             if verbose:
+                 print("[torch_export_patches] restored ShapeEnv._check_frozen")
+             ShapeEnv._check_frozen = f_shape_env__check_frozen
+
+     if patch_torch and catch_constraints:
+         # to catch or skip dynamic_shapes issues
+         torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
+             f_produce_guards_and_solve_constraints
+         )
+         torch._export.utils._check_input_constraints_for_graph = (
+             f__check_input_constraints_for_graph
+         )
+         if verbose:
+             print("[torch_export_patches] restored shape constraints")
+
+
+ def _patch_transformers(
+     verbose: int, patch_details: PatchDetails
+ ) -> Tuple[Optional[Callable], ...]:
+     import transformers
+
+     try:
+         import transformers.masking_utils as masking_utils
+     except ImportError:
+         masking_utils = None
+
+     try:
+         import transformers.integrations.sdpa_attention as sdpa_attention
+     except ImportError:
+         sdpa_attention = None
+
+     try:
+         import transformers.modeling_utils as modeling_utils
+     except ImportError:
+         modeling_utils = None
+
+     try:
+         import transformers.modeling_rope_utils as modeling_rope_utils
+     except ImportError:
+         modeling_rope_utils = None
+
+     if (
+         patch_details
+         and modeling_rope_utils
+         and hasattr(modeling_rope_utils, "dynamic_rope_update")
+     ):
+         patch_details.append(
+             "patch_transformers",
+             modeling_rope_utils.dynamic_rope_update,
+             patch_transformers_list.patched_dynamic_rope_update,
+         )
+
+     if verbose:
+         print(f"[torch_export_patches] transformers.__version__={transformers.__version__!r}")
+     assert not sdpa_attention.sdpa_attention_forward.__name__.startswith("patched_"), (
+         f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+         f"sdpa_attention.sdpa_attention_forward={sdpa_attention.sdpa_attention_forward}"
+     )
+
+     f_transformers__vmap_for_bhqkv = None
+     f_transformers_eager_mask = None
+     f_transformers_sdpa_attention_forward = None
+     f_transformers_sdpa_mask = None
+     f_transformers_sdpa_mask_recent_torch = None
+
+     if (  # vmap
+         masking_utils
+         and patch_transformers_list.patch_masking_utils
+         and hasattr(masking_utils, "_vmap_for_bhqkv")
+     ):
+         if verbose:
+             print("[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv")
+         f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
+         masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
+         if patch_details:
+             patch_details.append(
+                 "transformers",
+                 f_transformers__vmap_for_bhqkv,
+                 patch_transformers_list.patched__vmap_for_bhqkv,
+             )
+
+         if verbose:
+             print(
+                 "[torch_export_patches] patches "
+                 "transformers.masking_utils.sdpa_mask_recent_torch"
+             )
+         f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch
+         masking_utils.sdpa_mask_recent_torch = (
+             patch_transformers_list.patched_sdpa_mask_recent_torch
+         )
+         if patch_details:
+             patch_details.append(
+                 "transformers",
+                 f_transformers_sdpa_mask_recent_torch,
+                 patch_transformers_list.patched_sdpa_mask_recent_torch,
+             )
+         if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch:
+             if verbose:
+                 print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask")
+             f_transformers_sdpa_mask = masking_utils.sdpa_mask
+             masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch
+             if patch_details:
+                 patch_details.append(
+                     "transformers",
+                     f_transformers_sdpa_mask,
+                     patch_transformers_list.patched_sdpa_mask_recent_torch,
+                 )
+         else:
+             f_transformers_sdpa_mask = None
+
+     if (  # eager_mask
+         masking_utils
+         and patch_transformers_list.patch_masking_utils
+         and hasattr(masking_utils, "eager_mask")
+     ):
+         if verbose:
+             print("[torch_export_patches] patches transformers.masking_utils.eager_mask")
+         f_transformers_eager_mask = masking_utils.eager_mask
+         masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+         if patch_details:
+             patch_details.append(
+                 "transformers",
+                 f_transformers_eager_mask,
+                 patch_transformers_list.patched_eager_mask,
+             )
+         if (
+             "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+             == f_transformers_eager_mask
+         ):
+             if verbose:
+                 print(
+                     "[torch_export_patches] patches "
+                     "transformers.masking_utils.eager_mask "
+                     "in ALL_MASK_ATTENTION_FUNCTIONS"
+                 )
+             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                 patch_transformers_list.patched_eager_mask
+             )
+
+     if (  # sdpa_mask
+         masking_utils
+         and patch_transformers_list.patch_masking_utils
+         and hasattr(masking_utils, "sdpa_mask")
+         and f_transformers_sdpa_mask is not None
+     ):
+         if verbose:
+             print(
+                 "[torch_export_patches] patches "
+                 "transformers.masking_utils.sdpa_mask "
+                 "in ALL_MASK_ATTENTION_FUNCTIONS"
+             )
+         if (
+             "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] == f_transformers_sdpa_mask
+         ):
+             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = (
+                 patch_transformers_list.patched_sdpa_mask_recent_torch
+             )
+
+     if (  # sdpa_attention_forward
+         sdpa_attention is not None
+         and modeling_utils is not None
+         and hasattr(sdpa_attention, "sdpa_attention_forward")
+         and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+         and hasattr(modeling_utils, "AttentionInterface")
+     ):
+         if verbose:
+             print(
+                 "[torch_export_patches] patches "
+                 "transformers.integrations.sdpa_attention.sdpa_attention_forward"
+             )
+         f_transformers_sdpa_attention_forward = sdpa_attention.sdpa_attention_forward
+         assert not f_transformers_sdpa_attention_forward.__name__.startswith("patched_"), (
+             f"Function 'sdpa_attention.sdpa_attention_forward' is already patched, "
+             f"sdpa_attention.sdpa_attention_forward={f_transformers_sdpa_attention_forward}"
+         )
+         sdpa_attention.sdpa_attention_forward = (
+             patch_transformers_list.patched_sdpa_attention_forward
+         )
+         modeling_utils.sdpa_attention_forward = (
+             patch_transformers_list.patched_sdpa_attention_forward
+         )
+         modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+             patch_transformers_list.patched_sdpa_attention_forward
+         )
+         if patch_details:
+             patch_details.append(
+                 "transformers",
+                 f_transformers_sdpa_attention_forward,
+                 patch_transformers_list.patched_sdpa_attention_forward,
+             )
+
+     revert_patches_info = patch_module_or_classes(
+         patch_transformers_list, verbose=verbose, patch_details=patch_details
+     )
+
+     return (
+         f_transformers__vmap_for_bhqkv,
+         f_transformers_eager_mask,
+         f_transformers_sdpa_attention_forward,
+         f_transformers_sdpa_mask,
+         f_transformers_sdpa_mask_recent_torch,
+         revert_patches_info,
+     )
+
+
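While the context manager is active, the swap can be observed directly in transformers (assuming a version that exposes `ALL_MASK_ATTENTION_FUNCTIONS`); a hedged check:

    import transformers.masking_utils as masking_utils

    fn = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
    # expected to be 'patched_sdpa_mask_recent_torch' while patched
    print(fn.__name__)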
+ def _unpatch_transformers(
+     verbose: int,
+     _patch_details: PatchDetails,
+     f_transformers__vmap_for_bhqkv: Optional[Callable],
+     f_transformers_eager_mask: Optional[Callable],
+     f_transformers_sdpa_attention_forward: Optional[Callable],
+     f_transformers_sdpa_mask: Optional[Callable],
+     f_transformers_sdpa_mask_recent_torch: Optional[Callable],
+     revert_patches_info: Optional[Callable],
+ ):
+     try:
+         import transformers.masking_utils as masking_utils
+     except ImportError:
+         masking_utils = None
+
+     try:
+         import transformers.integrations.sdpa_attention as sdpa_attention
+     except ImportError:
+         sdpa_attention = None
+
+     try:
+         import transformers.modeling_utils as modeling_utils
+     except ImportError:
+         modeling_utils = None
+
+     if verbose:
+         print("[torch_export_patches] unpatches transformers")
+
+     if (  # vmap
+         masking_utils
+         and patch_transformers_list.patch_masking_utils
+         and hasattr(masking_utils, "_vmap_for_bhqkv")
+     ):
+         assert f_transformers__vmap_for_bhqkv.__name__ == "_vmap_for_bhqkv", (
+             f"corrupted function '_vmap_for_bhqkv', its name is "
+             f"{f_transformers__vmap_for_bhqkv.__name__!r}"
+         )
+         masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
+         if verbose:
+             print("[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv")
+
+         assert f_transformers_sdpa_mask_recent_torch.__name__ == "sdpa_mask_recent_torch", (
+             f"corrupted function 'sdpa_mask_recent_torch', its name is "
+             f"{f_transformers_sdpa_mask_recent_torch.__name__!r}"
+         )
+         masking_utils.sdpa_mask_recent_torch = f_transformers_sdpa_mask_recent_torch
+
+         if verbose:
+             print(
+                 "[torch_export_patches] restored "
+                 "transformers.masking_utils.sdpa_mask_recent_torch"
+             )
+
+         if f_transformers_sdpa_mask is not None:
+             assert f_transformers_sdpa_mask.__name__ in (
+                 "sdpa_mask",
+                 "sdpa_mask_recent_torch",
+             ), (
+                 f"corrupted function 'sdpa_mask', its name is "
+                 f"{f_transformers_sdpa_mask.__name__!r}"
+             )
+             masking_utils.sdpa_mask = f_transformers_sdpa_mask
+             if verbose:
+                 print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask")
+
+     if (  # eager_mask
+         masking_utils
+         and patch_transformers_list.patch_masking_utils
+         and hasattr(masking_utils, "eager_mask")
+     ):
+         assert f_transformers_eager_mask.__name__ == "eager_mask", (
+             f"corrupted function 'eager_mask', its name is "
+             f"{f_transformers_eager_mask.__name__!r}"
+         )
+         masking_utils.eager_mask = f_transformers_eager_mask
+         if verbose:
+             print("[torch_export_patches] restored transformers.masking_utils.eager_mask")
+         assert masking_utils.eager_mask.__name__ == "eager_mask", (
+             f"corrupted function 'eager_mask', its name is "
+             f"{masking_utils.eager_mask.__name__!r}"
+         )
+         if (
+             "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+             == patch_transformers_list.patched_eager_mask
+         ):
+             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = f_transformers_eager_mask
+             if verbose:
+                 print(
+                     "[torch_export_patches] restored "
+                     "transformers.masking_utils.eager_mask "
+                     "in ALL_MASK_ATTENTION_FUNCTIONS"
+                 )
+         assert masking_utils.eager_mask.__name__ == "eager_mask", (
+             f"corrupted function 'eager_mask', its name is "
+             f"{masking_utils.eager_mask.__name__!r}"
+         )
+
+     if (  # sdpa_mask
+         masking_utils
+         and patch_transformers_list.patch_masking_utils
+         and hasattr(masking_utils, "sdpa_mask")
+     ):
+         if (
+             "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+             and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
+             == patch_transformers_list.patched_sdpa_mask_recent_torch
+         ):
+             masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = f_transformers_sdpa_mask
+             if verbose:
+                 print(
+                     "[torch_export_patches] restored "
+                     "transformers.masking_utils.sdpa_mask "
+                     "in ALL_MASK_ATTENTION_FUNCTIONS"
+                 )
+
+     if (  # sdpa_attention_forward
+         sdpa_attention is not None
+         and modeling_utils is not None
+         and hasattr(sdpa_attention, "sdpa_attention_forward")
+         and hasattr(sdpa_attention, "use_gqa_in_sdpa")
+         and hasattr(modeling_utils, "AttentionInterface")
+     ):
+         sdpa_attention.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+         modeling_utils.sdpa_attention_forward = f_transformers_sdpa_attention_forward
+         modeling_utils.AttentionInterface._global_mapping["sdpa"] = (
+             f_transformers_sdpa_attention_forward
+         )
+         if verbose:
+             print(
+                 "[torch_export_patches] restored "
+                 "transformers.integrations.sdpa_attention."
+                 "sdpa_attention_forward"
+             )
+
+     unpatch_module_or_classes(patch_transformers_list, revert_patches_info, verbose=verbose)
+
+
+ @contextlib.contextmanager
+ def torch_export_patches(
+     patch_sympy: bool = True,
+     patch_torch: Union[bool, int] = True,
+     patch_transformers: bool = False,
+     patch_diffusers: bool = False,
+     catch_constraints: bool = True,
+     stop_if_static: int = 0,
+     verbose: int = 0,
+     patch: bool = True,
+     custom_patches: Optional[List[type["torch.nn.Module"]]] = None,  # noqa: F821
+     rewrite: Optional[List[Callable]] = None,
+     dump_rewriting: Optional[str] = None,
+     patch_details: Optional[PatchDetails] = None,
+ ) -> Callable:
+     """
+     Tries to bypass some situations :func:`torch.export.export` does not support.
+     See also :ref:`l-patches-explained` and :ref:`l-patch-coverage`.
+
+     :param patch_sympy: fix missing method ``name`` for IntegerConstant
+     :param patch_torch: patches :epkg:`torch` with supported implementation
+     :param patch_transformers: patches :epkg:`transformers` with supported implementation
+     :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
+     :param catch_constraints: catch constraints related to dynamic shapes;
+         as a result, some dynamic dimensions may turn into static ones,
+         the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
+         can be set to stop at that stage
+     :param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
+         to stop the export as soon as an issue is detected with dynamic shapes
+         and show a stack trace indicating the exact location of the issue;
+         if ``stop_if_static > 1``, more methods are replaced to catch more
+         issues
+     :param patch: if False, disables all patches but keeps the registration of
+         serialization functions if other patch options are enabled
+     :param custom_patches: to apply custom patches,
+         every patched class must define static attributes
+         ``_PATCHES_``, ``_PATCHED_CLASS_``
+     :param rewrite: list of methods to automatically rewrite
+         before exporting; methods with control flow need to be rewritten
+         before being exported if the execution path depends on the inputs,
+         this is done by function :func:`transform_method
+         <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
+         its documentation provides possible values
+     :param dump_rewriting: dumps rewriting information in files beginning with that prefix
+     :param patch_details: if specified, this instance is used to store every applied
+         patch and rewrite
+     :param verbose: to show which patches are applied
+
+     The list of available patches:
+
+     * ``torch.jit.isinstance``
+     * ``torch._dynamo.mark_static_address``
+     * ``torch._subclasses.fake_impls.infer_size``
+     * ``torch.vmap``
+     * fix missing method ``name`` for ``sympy.S.IntegerConstant``
+     * ``AttentionMaskConverter._make_causal_mask``
+     * Serialization of ``MambaCache`` (in :epkg:`transformers`)
+     * Serialization of ``DynamicCache`` (in :epkg:`transformers`)
+     * reduce errors due to shape inference
+     * fixes some transformers classes,
+       see :mod:`onnx_diagnostic.torch_export_patches.patches.patch_transformers`
+
+     Serialization issues happen when a module takes an input or produces an output
+     whose type :func:`torch.export.export` cannot serialize.
+
+     Examples:
+
+     .. code-block:: python
+
+         with torch_export_patches(patch_transformers=True) as modificator:
+             inputs = modificator(inputs)
+             onx = to_onnx(..., inputs, ...)
+
+     .. code-block:: python
+
+         with torch_export_patches(patch_transformers=True) as modificator:
+             inputs = modificator(inputs)
+             onx = torch.onnx.export(..., inputs, ...)
+
+     It can be used as well to fix the torch export:
+
+     .. code-block:: python
+
+         with torch_export_patches(patch_transformers=True) as modificator:
+             inputs = modificator(inputs)
+             ep = torch.export.export(..., inputs, ...)
+
+     When running the model through the exported program, only the
+     serialization functions need to be restored:
+
+     .. code-block:: python
+
+         with register_additional_serialization_functions() as modificator:
+             inputs = modificator(inputs)
+             ep = torch.export.export(..., inputs, ...)
+
+     When exporting a model with a cache, the following error message
+     may appear: ``AssertionError: Mutating module attribute _seen_tokens during export.``
+     It can be avoided by setting ``strict=False`` when calling :func:`torch.export.export`.
+     """
+     if verbose:
+         print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
+         print(f" . patch_torch={patch_torch!r}")
+         print(f" . patch_transformers={patch_transformers!r}")
+         print(f" . patch_diffusers={patch_diffusers!r}")
+         print(f" . catch_constraints={catch_constraints!r}")
+         print(f" . stop_if_static={stop_if_static!r}")
+         print(f" . patch={patch!r}")
+         print(f" . custom_patches={custom_patches!r}")
+         print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
+
+     if rewrite:
+         from .patch_module import torch_export_rewrite
+
+         with (
+             torch_export_rewrite(
+                 rewrite=rewrite,
+                 dump_rewriting=dump_rewriting,
+                 verbose=verbose,
+                 patch_details=patch_details,
+             ),
+             torch_export_patches(  # type: ignore[var-annotated]
+                 patch_sympy=patch_sympy,
+                 patch_torch=patch_torch,
+                 patch_transformers=patch_transformers,
+                 patch_diffusers=patch_diffusers,
+                 catch_constraints=catch_constraints,
+                 stop_if_static=stop_if_static,
+                 verbose=verbose,
+                 patch=patch,
+                 custom_patches=custom_patches,
+                 patch_details=patch_details,
+             ) as f,
+         ):
+             try:
+                 yield f
+             finally:
+                 pass
+     elif not patch:
+         fct_callable = lambda x: x  # noqa: E731
+         done = register_cache_serialization(
+             patch_transformers=patch_transformers,
+             patch_diffusers=patch_diffusers,
+             verbose=verbose,
+         )
+         try:
+             yield fct_callable
+         finally:
+             unregister_cache_serialization(done, verbose=verbose)
+     else:
+         if verbose:
+             print(
+                 "[torch_export_patches] replace torch.jit.isinstance, "
+                 "torch._dynamo.mark_static_address"
+             )
+
+         # caches
+
+         cache_done = register_cache_serialization(
+             patch_transformers=patch_transformers,
+             patch_diffusers=patch_diffusers,
+             verbose=verbose,
+         )
+
+         # patches
+
+         if patch_sympy:
+             (f_sympy_name,) = _patch_sympy(verbose, patch_details)
+
+         if patch_torch:
+             (
+                 f___constrain_user_specified_dimhint_range,
+                 f__broadcast_in_dim_meta,
+                 f__broadcast_shapes,
+                 f__check_input_constraints_for_graph,
+                 f__maybe_broadcast,
+                 f_broadcast_in_dim,
+                 f_infer_size,
+                 f_jit_isinstance,
+                 f_mark_static_address,
+                 f_produce_guards_and_solve_constraints,
+                 f_shape_env__check_frozen,
+                 f_shape_env__evaluate_expr,
+                 f_shape_env__log_guard,
+                 f_shape_env__set_replacement,
+                 f_vmap,
+             ) = _patch_torch(
+                 verbose, patch_details, patch_torch, catch_constraints, stop_if_static
+             )
+
+         if patch_transformers:
+             (
+                 f_transformers__vmap_for_bhqkv,
+                 f_transformers_eager_mask,
+                 f_transformers_sdpa_attention_forward,
+                 f_transformers_sdpa_mask,
+                 f_transformers_sdpa_mask_recent_torch,
+                 revert_patches_info,
+             ) = _patch_transformers(verbose, patch_details)
+
+         if custom_patches:
+             if verbose:
+                 print("[torch_export_patches] applies custom patches")
+             revert_custom_patches_info = patch_module_or_classes(
+                 custom_patches, verbose=verbose, patch_details=patch_details
+             )
+
+         # export
+
+         fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
+
+         if verbose:
+             print("[torch_export_patches] done patching")
+
+         try:
+             yield fct_callable
+         finally:
+
+             # unpatch
+
+             if verbose:
+                 print("[torch_export_patches] remove patches")
+
+             if patch_sympy:
+                 _unpatch_sympy(verbose, f_sympy_name)
+
+             if patch_torch:
+                 _unpatch_torch(
+                     verbose,
+                     patch_details,
+                     patch_torch,
+                     catch_constraints,
+                     stop_if_static,
+                     f___constrain_user_specified_dimhint_range,
+                     f__broadcast_in_dim_meta,
+                     f__broadcast_shapes,
+                     f__check_input_constraints_for_graph,
+                     f__maybe_broadcast,
+                     f_broadcast_in_dim,
+                     f_infer_size,
+                     f_jit_isinstance,
+                     f_mark_static_address,
+                     f_produce_guards_and_solve_constraints,
+                     f_shape_env__check_frozen,
+                     f_shape_env__evaluate_expr,
+                     f_shape_env__log_guard,
+                     f_shape_env__set_replacement,
+                     f_vmap,
+                 )
+
+             if patch_transformers:
+                 _unpatch_transformers(
+                     verbose,
+                     patch_details,
+                     f_transformers__vmap_for_bhqkv,
+                     f_transformers_eager_mask,
+                     f_transformers_sdpa_attention_forward,
+                     f_transformers_sdpa_mask,
+                     f_transformers_sdpa_mask_recent_torch,
+                     revert_patches_info,
+                 )
+
+             if custom_patches:
+                 if verbose:
+                     print("[torch_export_patches] unpatches custom patches")
+                 unpatch_module_or_classes(
+                     custom_patches, revert_custom_patches_info, verbose=verbose
+                 )
+
+             ########
+             # caches
+             ########
+
+             unregister_cache_serialization(cache_done, verbose=verbose)
+
+
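Putting it together, a minimal end-to-end sketch (`model` and `inputs` stand for a transformers model and its keyword inputs; `strict=False` follows the advice in the docstring above):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    with torch_export_patches(patch_transformers=True, stop_if_static=1) as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(model, (), kwargs=inputs, strict=False)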
+ def replacement_before_exporting(args: Any) -> Any:
+     """Does replacements on the given inputs if needed."""
+     if args is None:
+         return None
+     if isinstance(args, (int, float)):
+         return args
+     if type(args) not in {dict, tuple, list}:
+         # BaseModelOutput is a dict subclass, it is returned unchanged.
+         return args
+     if isinstance(args, dict):
+         return {k: replacement_before_exporting(v) for k, v in args.items()}
+     if isinstance(args, tuple):
+         return tuple(replacement_before_exporting(v) for v in args)
+     if isinstance(args, list):
+         return [replacement_before_exporting(v) for v in args]
+
+     return args
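Because the container check uses exact `type()`, dict subclasses such as `BaseModelOutput` pass through untouched while plain containers are rebuilt recursively:

    replacement_before_exporting({"a": [1, 2.5], "b": (None,)})
    # -> {'a': [1, 2.5], 'b': (None,)}, each leaf visited recursively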