onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -0,0 +1,1090 @@
1
+ import functools
2
+ import inspect
3
+ import operator
4
+ import os
5
+ import traceback
6
+ from functools import reduce
7
+ from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
8
+ import torch
9
+ from torch._subclasses.fake_tensor import FakeTensorMode
10
+
11
+
12
+ def retrieve_stacktrace():
13
+ """Retrieves and prints the current stack trace, avoids every torch file."""
14
+ rows = []
15
+ stack_frames = traceback.extract_stack()
16
+ for frame in stack_frames:
17
+ filename, lineno, function_name, code_line = frame
18
+ if "/torch/" in filename:
19
+ continue
20
+ rows.append(f"File: {filename}, Line {lineno}, in {function_name}")
21
+ if code_line:
22
+ rows.append(f" {code_line}")
23
+ return "\n".join(rows)
24
+
25
+
26
+ def _catch_produce_guards_and_solve_constraints(
27
+ previous_function: Callable,
28
+ fake_mode: FakeTensorMode,
29
+ gm: torch.fx.GraphModule,
30
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
31
+ equalities_inputs: "EqualityConstraint", # noqa: F821
32
+ original_signature: inspect.Signature,
33
+ verbose: int = 0,
34
+ **kwargs,
35
+ ):
36
+ try:
37
+ return previous_function(
38
+ fake_mode=fake_mode,
39
+ gm=gm,
40
+ dynamic_shapes=dynamic_shapes,
41
+ equalities_inputs=equalities_inputs,
42
+ original_signature=original_signature,
43
+ **kwargs,
44
+ )
45
+ except Exception as e:
46
+ if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
47
+ raise
48
+ if verbose:
49
+ print(
50
+ f"[_catch_produce_guards_and_solve_constraints] ERROR: "
51
+ f"produce_guards_and_solve_constraints failed, "
52
+ f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
53
+ f"fake_mode={fake_mode}\n"
54
+ f"dynamic_shapes={dynamic_shapes}\n"
55
+ f"equalities_inputs={equalities_inputs}\n"
56
+ f"original_signature={original_signature}\n"
57
+ f"kwargs={kwargs}\n"
58
+ f"exc={e}\ngm={gm}"
59
+ )
60
+ torch._dynamo.reset()
61
+
62
+
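# Editor's note (hedged sketch): the wrapper above expects the original torch
# function to be bound as ``previous_function`` by the patching machinery that
# lives elsewhere in this package (presumably onnx_export_errors.py, listed
# above). A minimal wiring could look like the following; the module path is
# an assumption inferred from the argument names, not taken from this file.
import functools
import torch._export.non_strict_utils as non_strict_utils  # assumed location

_original = non_strict_utils.produce_guards_and_solve_constraints
non_strict_utils.produce_guards_and_solve_constraints = functools.partial(
    _catch_produce_guards_and_solve_constraints, _original
)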
63
+ def patch__check_input_constraints_for_graph(
64
+ previous_function: Callable,
65
+ input_placeholders: list[torch.fx.Node],
66
+ flat_args_with_path,
67
+ range_constraints,
68
+ verbose: int = 0,
69
+ ) -> None:
70
+ try:
71
+ # PATCHED: catches exception and prints out the information instead of
72
+ # stopping the conversion.
73
+ return previous_function(input_placeholders, flat_args_with_path, range_constraints)
74
+ except Exception as e:
75
+ if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
76
+ raise
77
+ if verbose:
78
+ print(
79
+ f"[_check_input_constraints_for_graph] ERROR: "
80
+ f"_check_input_constraints_for_graph failed, "
81
+ f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
82
+ f"input_placeholders={input_placeholders}\n"
83
+ f"range_constraints={range_constraints}\n"
84
+ f"exc={e}"
85
+ )
86
+ torch._dynamo.reset()
87
+
88
+
89
+ def patched_infer_size(a, b):
90
+ """Patches ``torch._subclasses.fake_impls.infer_size``."""
91
+ from torch.fx.experimental.symbolic_shapes import guard_or_false
92
+
93
+ dimsA = len(a)
94
+ dimsB = len(b)
95
+ ndim = max(dimsA, dimsB)
96
+ expandedSizes = [0] * ndim
97
+ for i in range(ndim - 1, -1, -1):
98
+ offset = ndim - 1 - i
99
+ dimA = dimsA - 1 - offset
100
+ dimB = dimsB - 1 - offset
101
+ sizeA = a[dimA] if dimA >= 0 else 1
102
+ sizeB = b[dimB] if dimB >= 0 else 1
103
+
104
+ # NB: It is very important to test for broadcasting, before testing
105
+ # sizeA == sizeB. This is because the broadcasting tests are likely
106
+ # to be statically known (in particular, if sizeA/sizeB is unbacked
107
+ # but size-like, we will unsoundly assume they never equal 1), but
108
+ # the sizeA == sizeB test may not be statically known. However, once
109
+ # we have established that no broadcasting is happening, the
110
+ # sizeA == sizeB is now expect_true and we can defer it as a runtime
111
+ # assert (this works because Python will return the terminal
112
+ # expression of an or statement as-is, without bool()'ing it; if this
113
+ # were not the case, we'd need to write this using torch.sym_or() or
114
+ # something like that).
115
+ try:
116
+ b1 = guard_or_false(sizeA == 1)
117
+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
118
+ b1 = False
119
+ try:
120
+ b2 = guard_or_false(sizeB == 1)
121
+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
122
+ b2 = False
123
+ try:
124
+ b3 = guard_or_false(sizeA == sizeB)
125
+ except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
126
+ b3 = False
127
+ if b1 or b2 or b3:
128
+ expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
129
+ else:
130
+ # PATCHED: generic case, the dimension is known, no need to assert
131
+ expandedSizes[i] = torch.sym_max(sizeA, sizeB)
132
+ return tuple(expandedSizes)
133
+
134
+
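# Editor's note (illustration only): with plain Python ints, guard_or_false
# reduces to an ordinary boolean, so the patched function reproduces the usual
# broadcasting rules; the interesting behaviour only appears with symbolic
# (possibly unbacked) sizes, where unresolved comparisons fall back to
# torch.sym_max instead of raising a data-dependent guard error.
assert patched_infer_size((2, 1, 3), (4, 3)) == (2, 4, 3)
assert patched_infer_size((5, 1), (1, 7)) == (5, 7)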
135
+ def patched__broadcast_shapes(*_shapes):
136
+ """Patches ``torch._refs._broadcast_shapes``."""
137
+ from functools import reduce
138
+ from torch._prims_common import IntLike
139
+ from torch.fx.experimental.symbolic_shapes import (
140
+ guard_or_false,
141
+ is_nested_int,
142
+ )
143
+
144
+ shapes = tuple(
145
+ (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
146
+ )
147
+
148
+ # Short-circuits on no input
149
+ if len(shapes) == 0:
150
+ return None
151
+
152
+ for shape in shapes:
153
+ if not isinstance(shape, Sequence):
154
+ raise RuntimeError(
155
+ "Input shapes should be of type ints, a tuple of ints, "
156
+ "or a list of ints, got ",
157
+ shape,
158
+ )
159
+
160
+ # Computes common shape
161
+ common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
162
+ for _arg_idx, shape in enumerate(shapes):
163
+ for idx in range(-1, -1 - len(shape), -1):
164
+ if is_nested_int(shape[idx]):
165
+ # Broadcasting is allowed for (j0, 1) or (j0, j0);
166
+ # not (j0, j1), (j0, 5), etc.
167
+ if is_nested_int(common_shape[idx]) and guard_or_false(
168
+ shape[idx] == common_shape[idx]
169
+ ):
170
+ continue
171
+ else:
172
+ if guard_or_false(shape[idx] == common_shape[idx]):
173
+ continue
174
+ # PATCHED: two cases, if == for sure, no broadcast,
175
+ # otherwise maybe broadcast with max(dimensions)
176
+ if guard_or_false(common_shape[idx] != 1):
177
+ pass
178
+ elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
179
+ if shape[idx] < 0:
180
+ raise ValueError(
181
+ "Attempting to broadcast a dimension with negative length!"
182
+ )
183
+ common_shape[idx] = shape[idx]
184
+ else:
185
+ common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
186
+
187
+ return common_shape
188
+
189
+
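# Editor's note (illustration only): scalars are promoted to 1-d shapes and
# None entries are dropped before the common shape is computed, so with
# concrete ints the result matches torch.broadcast_shapes (returned as a list).
assert patched__broadcast_shapes((2, 1, 3), (4, 3)) == [2, 4, 3]
assert patched__broadcast_shapes(5, None, (3, 1)) == [3, 5]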
190
+ class patched_ShapeEnv:
191
+
192
+ def _check_frozen(
193
+ self, expr: "sympy.Basic", concrete_val: "sympy.Basic" # noqa: F821
194
+ ) -> None:
195
+ if self.frozen:
196
+ self.counter["ignored_backward_guard"] += 1
197
+ # PATCHED: raises an exception instead of logging.
198
+ raise AssertionError(
199
+ f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
200
+ f"this could result in accuracy problems"
201
+ )
202
+
203
+ def _set_replacement(
204
+ self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str # noqa: F821
205
+ ) -> None:
206
+ """
207
+ Adds or updates a replacement for a symbol.
208
+ Use this instead of `self.replacements[a] = tgt`.
209
+ """
210
+ if tgt == self.replacements.get(a, None):
211
+ return
212
+
213
+ if a in tgt.free_symbols:
214
+ return
215
+
216
+ import sympy
217
+ from torch._logging import structured
218
+ from torch.utils._traceback import CapturedTraceback
219
+ from torch._logging import trace_structured
220
+ from torch._guards import TracingContext
221
+ from torch.utils._sympy.functions import FloorToInt, CeilToInt
222
+ from torch.utils._sympy.solve import try_solve
223
+ from torch.fx.experimental.symbolic_shapes import (
224
+ _is_supported_equivalence,
225
+ ValueRanges,
226
+ )
227
+
228
+ # Precondition: a == tgt
229
+ assert isinstance(a, sympy.Symbol)
230
+
231
+ if (
232
+ getattr(self, "allow_complex_guards_as_runtime_asserts", False)
233
+ or getattr(self, "prefer_deferred_runtime_asserts_over_guards", False)
234
+ ) and not _is_supported_equivalence(tgt):
235
+ # continuing leads to placeholder shapes
236
+ # having complex expressions that we can't resolve
237
+ return
238
+
239
+ # Handles nested tensor symbolic variables which don't have
240
+ # var_to_range bounds
241
+ tgt_bound = None
242
+ if a in self.var_to_range:
243
+ src_bound = self.var_to_range[a]
244
+
245
+ # First, refine the value range of a based on the computed value range
246
+ # of tgt. This is always OK to do, even if we decide not to do the
247
+ # substitution in the end. This might be a no-op, if a already has
248
+ # a tighter bound
249
+ tgt_bound = self.bound_sympy(tgt)
250
+ self._update_var_to_range(a, tgt_bound)
251
+
252
+ # Next, check if we can update the range of free symbols in tgt
253
+ # based on the range in a. But only do it if:
254
+ # - the source bound non-trivially improves over what we get out of
255
+ # the existing bounds.
256
+ # - the replacement is univariate and we can invert the tgt expression
257
+ if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
258
+ b = next(iter(tgt.free_symbols))
259
+ # Try to invert the equality
260
+ r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
261
+ if r is not None:
262
+ self.log.debug(
263
+ "set_replacement: solve for %s in %s == %s gives %s",
264
+ b,
265
+ a,
266
+ tgt,
267
+ r,
268
+ )
269
+ # The solution here can be non-integral, for example, if
270
+ # we have s0 = 2*s1, then s1 = s0/2. What we would like
271
+ # to do is calculated the bounds in arbitrary precision,
272
+ # and then requantize the bound to integers when we are
273
+ # done.
274
+ rat_b_bound = self.bound_sympy(r[1])
275
+ b_bound = ValueRanges(
276
+ CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)
277
+ )
278
+ self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
279
+ tgt_bound = self.bound_sympy(tgt)
280
+ assert tgt_bound.issubset(
281
+ src_bound
282
+ ), f"{tgt_bound=} not a subset of {src_bound=}"
283
+
284
+ # TODO: Should we propagate size-like-ness?
285
+ #
286
+ # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
287
+ # to become size-like.
288
+ #
289
+ # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
290
+ # propagate in this case, because what if u0 == 0, then u1 is negative
291
+ # and clearly isn't a size. So, at minimum, any f(x) whose value
292
+ # range isn't [0, inf] given x in [0, inf] cannot propagate
293
+ # size-like-ness. But there are many situations where you could
294
+ # imagine u1 is going to be size-like and actually you just didn't
295
+ # have a refined enough value range on u0. Since even innocuous
296
+ # looking arithmetic operations can destroy size-like-ness, it's
297
+ # best to not propagate it at all and force the user to annotate it
298
+ # as necessary.
299
+ #
300
+ # Compromise: we preserve size-like-ness only for exact equality
301
+ # and nothing else.
302
+ if a in self.size_like and isinstance(tgt, sympy.Symbol):
303
+ self.size_like.add(tgt)
304
+ elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
305
+ self.size_like.add(a)
306
+
307
+ # Now, decide if we will do the substitution.
308
+ #
309
+ # - If the source has a non-trivial range, only substitute if
310
+ # we preserve this range. Note that we may have propagated
311
+ # the src_range to free variables in tgt when tgt is univariate
312
+ # and we could find an inverse, which helps us achieve this.
313
+ # This ensures we never "forget" about user defined ranges,
314
+ # even if they end up being defined on composite formulas
315
+ # like s0 + s1.
316
+ #
317
+ # - If the variable is unbacked, only substitute if the substitution
318
+ # would preserve the bounds also under size-like-ness conditions.
319
+
320
+ if not tgt_bound.issubset(src_bound):
321
+ self.log.debug(
322
+ "skipped set_replacement %s = %s (%s) [%s not subset of %s]",
323
+ a,
324
+ tgt,
325
+ msg,
326
+ tgt_bound,
327
+ src_bound,
328
+ )
329
+ return
330
+ elif a in self.size_like:
331
+ tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
332
+ src_bound_so = self.bound_sympy(a, size_oblivious=True)
333
+ if not tgt_bound_so.issubset(src_bound_so):
334
+ self.log.debug(
335
+ "skipped set_replacement %s = %s (%s) "
336
+ "[%s not subset of %s (size-oblivious conditions)]",
337
+ a,
338
+ tgt,
339
+ msg,
340
+ tgt_bound_so,
341
+ src_bound_so,
342
+ )
343
+ return
344
+
345
+ if isinstance(tgt, (sympy.Integer, sympy.Float)):
346
+ # specializing to a constant, which is likely unexpected (unless
347
+ # you specified dynamic=True)
348
+
349
+ user_tb = TracingContext.extract_stack()
350
+ trace_structured(
351
+ "symbolic_shape_specialization",
352
+ metadata_fn=lambda: {
353
+ "symbol": repr(a),
354
+ "sources": [s.name() for s in self.var_to_sources.get(a, [])],
355
+ "value": repr(tgt),
356
+ "reason": msg,
357
+ "stack": structured.from_traceback(
358
+ CapturedTraceback.extract(skip=1).summary()
359
+ ),
360
+ "user_stack": (structured.from_traceback(user_tb) if user_tb else None),
361
+ },
362
+ )
363
+
364
+ for source in self.var_to_sources.get(a, []):
365
+ if user_tb:
366
+ self.specialization_stacks[source] = user_tb
367
+
368
+ # PATCHED: removed lines
369
+ # if config.print_specializations:
370
+ # self.log.warning(
371
+ # "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
372
+ # )
373
+ # self.log.debug("SPECIALIZATION", stack_info=True)
374
+ # PATCHED: replaces logging by raising an exception
375
+ assert msg != "range_refined_to_singleton", (
376
+ f"patched_ShapeEnv: A dynamic dimension becomes static! "
377
+ f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
378
+ )
379
+ # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
380
+ self.replacements[a] = tgt
381
+ # NB: the replacement may get refined, but the user will find the
382
+ # FIRST one most useful (TODO: Maybe we could consider tracking all of
383
+ # them)
384
+ if a not in self.replacements_slocs:
385
+ self.replacements_slocs[a] = self._get_sloc()
386
+ self._update_version_counter()
387
+
388
+ # When specializing 'a == tgt', the equality should be also conveyed to
389
+ # Z3, in case an expression uses 'a'.
390
+ self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))
391
+
392
+ def _log_guard(
393
+ self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821
394
+ ) -> None:
395
+ self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
396
+ # PATCHED: removed
397
+ # It happens too often to be relevant.
398
+ # sloc, _maybe_extra_debug = self._get_stack_summary(True)
399
+ # warnings.warn(
400
+ # f"A guard was added, prefix={prefix!r}, g={g!r}, "
401
+ # f"forcing_spec={forcing_spec}, location=\n{sloc}\n"
402
+ # f"--stack trace--\n{retrieve_stacktrace()}",
403
+ # RuntimeWarning,
404
+ # stacklevel=0,
405
+ # )
406
+
407
+ def _evaluate_expr(
408
+ self,
409
+ orig_expr: "sympy.Basic", # noqa: F821
410
+ hint: Optional[Union[bool, int, float]] = None,
411
+ fx_node: Optional[torch.fx.Node] = None,
412
+ size_oblivious: bool = False,
413
+ fallback_value: Optional[bool] = None,
414
+ *,
415
+ forcing_spec: bool = False,
416
+ ) -> "sympy.Basic": # noqa: F821
417
+ # TODO: split conjunctions and evaluate them separately
418
+ import sympy
419
+ from torch.fx.experimental import _config as config
420
+ from torch.fx.experimental.symbolic_shapes import (
421
+ SympyBoolean,
422
+ log,
423
+ SymT,
424
+ symbol_is_type,
425
+ )
426
+ from torch._guards import ShapeGuard
427
+
428
+ if isinstance(
429
+ orig_expr,
430
+ (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
431
+ ):
432
+ return orig_expr
433
+
434
+ # Don't track this one. (Because this cache is inside this function the
435
+ # cache only lasts for the invocation of this function call)
436
+ @functools.cache
437
+ def compute_concrete_val() -> sympy.Basic:
438
+ if hint is None:
439
+ # This is only ever called for expressions WITHOUT unbacked
440
+ # symbols
441
+ r = self.size_hint(orig_expr)
442
+ assert r is not None
443
+ return r
444
+ else:
445
+ return sympy.sympify(hint)
446
+
447
+ concrete_val: Optional[sympy.Basic]
448
+
449
+ # Check if:
450
+ # 1. 'translation_validation' is set
451
+ # 2. the corresponding 'fx_node' is not 'None'
452
+ # 3. the guard should not be suppressed
453
+ # 4. the guard doesn't contain backed symfloat symbols
454
+ # since z3 can't handle floats
455
+ # 5. fallback_value is none.
456
+ # If all of the above checks pass, we create an FX node representing the
457
+ # actual expression to be guarded.
458
+ node = None
459
+ fresh = False
460
+ if (
461
+ self._translation_validation_enabled
462
+ and fx_node is not None
463
+ and not self._suppress_guards_tls()
464
+ and not size_oblivious
465
+ and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
466
+ and fallback_value is None
467
+ ):
468
+ # TODO: does this even work with unbacked :think:
469
+ concrete_val = compute_concrete_val()
470
+ if concrete_val is sympy.true:
471
+ node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
472
+ elif concrete_val is sympy.false:
473
+ neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
474
+ node, fresh = self._create_fx_call_function(torch._assert, (neg,))
475
+ else:
476
+ eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
477
+ node, fresh = self._create_fx_call_function(torch._assert, (eql,))
478
+
479
+ assert node is not None
480
+ # If this is a fresh node, we have to remember the event index that
481
+ # corresponds to this assertion node.
482
+ # Reason: so that, given an assertion node, we can replay the ShapeEnv
483
+ # events until the point where this assertion node was freshly created.
484
+ if fresh:
485
+ self._add_fx_node_metadata(node)
486
+
487
+ # After creating the FX node corresponding to orig_expr, we must make sure that
488
+ # no error will be raised until the end of this function.
489
+ #
490
+ # Reason: the translation validation may become invalid otherwise.
491
+ #
492
+ # If an error is raised before the end of this function, we remove the FX node
493
+ # inserted, and re-raise the error.
494
+ guard = None
495
+
496
+ try:
497
+ if orig_expr.is_number:
498
+ self.log.debug("eval %s [trivial]", orig_expr)
499
+ if hint is not None:
500
+ if isinstance(hint, bool):
501
+ assert orig_expr == hint, f"{orig_expr} != {hint}"
502
+ else:
503
+ assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
504
+ return orig_expr
505
+
506
+ expr = orig_expr
507
+
508
+ static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
509
+ if static_expr is not None:
510
+ self.log.debug(
511
+ "eval %s == %s [statically known]",
512
+ (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious),
513
+ static_expr,
514
+ )
515
+ if not size_oblivious and config.backed_size_oblivious and hint is not None:
516
+ # TODO: maybe reconcile this with use of counterfactual hints
517
+ # in unbacked case
518
+ assert static_expr == hint, f"{static_expr} != {hint}"
519
+ return static_expr
520
+
521
+ transmute_into_runtime_assert = False
522
+
523
+ concrete_val = None
524
+ if not (expr.free_symbols <= self.var_to_val.keys()):
525
+ # TODO: dedupe this with _maybe_evaluate_static
526
+ # Attempt to eliminate the unbacked SymInt
527
+ new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
528
+ assert new_expr is not None
529
+ if not (new_expr.free_symbols <= self.var_to_val.keys()):
530
+ ok = False
531
+
532
+ # fallback_value is set when guard_or_true or guard_or_false are used.
533
+ if not ok and fallback_value is not None:
534
+ self._log_suppressed_dde(orig_expr, fallback_value)
535
+ return fallback_value
536
+
537
+ # oblivious_var_to_val will be defined iff we have sizes
538
+ # with DimDynamic.OBLIVIOUS_SIZE type.
539
+ # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
540
+ if (
541
+ self.oblivious_var_to_val
542
+ and not (
543
+ correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
544
+ ).free_symbols
545
+ and not (
546
+ counterfactual_hint := orig_expr.xreplace(
547
+ {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
548
+ )
549
+ ).free_symbols
550
+ and correct_hint == counterfactual_hint
551
+ ):
552
+ # TODO: better logging
553
+ log.info(
554
+ "oblivious_size %s -> %s (passed counterfactual)",
555
+ orig_expr,
556
+ # pyrefly: ignore # unbound-name
557
+ correct_hint,
558
+ )
559
+ # pyrefly: ignore # unbound-name
560
+ concrete_val = correct_hint
561
+ # NB: do NOT transmute into runtime assert
562
+ ok = True
563
+
564
+ # unbacked_var_to_val is not None iff propagate_real_tensors is on.
565
+ # if propagate_real_tensors is on, we check the example values
566
+ # to generate (unsound_result)
567
+ # and if they pass we add a runtime assertions and continue.
568
+ if (
569
+ not ok
570
+ and self.unbacked_var_to_val
571
+ and not (
572
+ unsound_result := orig_expr.xreplace(
573
+ self.unbacked_var_to_val
574
+ ).xreplace(self.var_to_val)
575
+ ).free_symbols
576
+ ):
577
+ # pyrefly: ignore # unbound-name
578
+ self._log_real_tensor_propagation(orig_expr, unsound_result)
579
+ transmute_into_runtime_assert = True
580
+ # pyrefly: ignore # unbound-name
581
+ concrete_val = unsound_result
582
+ ok = True
583
+
584
+ # Check if this is coming from a python assert statement,
585
+ # if so, convert it to a runtime assertion
586
+ # instead of failing.
587
+ if not ok and self.trace_asserts and self._is_python_assert():
588
+ concrete_val = sympy.true
589
+ transmute_into_runtime_assert = True
590
+ ok = True
591
+
592
+ # PATCHED: ok -> True
593
+ ok = True
594
+ # if not ok:
595
+ # raise self._make_data_dependent_error(
596
+ # expr.xreplace(self.var_to_val),
597
+ # expr,
598
+ # expr_sym_node_id=self._expr_sym_node_id,
599
+ # )
600
+ else:
601
+ expr = new_expr
602
+
603
+ if concrete_val is None:
604
+ concrete_val = compute_concrete_val()
605
+ self._check_frozen(expr, concrete_val)
606
+
607
+ if (
608
+ config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
609
+ and isinstance(hint, bool)
610
+ and isinstance(expr, (sympy.Eq, sympy.Ne))
611
+ ):
612
+ expr = sympy.Not(expr)
613
+
614
+ # Turn this into a boolean expression, no longer need to consult
615
+ # concrete_val
616
+ if concrete_val is sympy.true:
617
+ g = cast(SympyBoolean, expr)
618
+ elif concrete_val is sympy.false:
619
+ g = sympy.Not(expr)
620
+ else:
621
+ g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]
622
+
623
+ if transmute_into_runtime_assert:
624
+ self.guard_or_defer_runtime_assert(
625
+ g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
626
+ )
627
+ return concrete_val
628
+
629
+ if not self._suppress_guards_tls():
630
+ self._log_guard("eval", g, forcing_spec=forcing_spec)
631
+
632
+ # TODO: If we successfully eliminate a symbol via equality, it
633
+ # is not actually necessary to save a guard for the equality,
634
+ # as we will implicitly generate a guard when we match that
635
+ # input against the symbol. Probably the easiest way to
636
+ # implement this is to have maybe_guard_rel return a bool
637
+ # saying if it "subsumed" the guard (and therefore the guard
638
+ # is no longer necessary)
639
+ self._maybe_guard_rel(g)
640
+
641
+ if (
642
+ torch.compiler.is_exporting()
643
+ and self.prefer_deferred_runtime_asserts_over_guards
644
+ ):
645
+ # it's fine to defer simple guards here without checking,
646
+ # the _maybe_guard_rel() call above will set replacements if possible,
647
+ # and so the result here will be statically known
648
+ self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
649
+ else:
650
+ # at this point, we've evaluated the concrete expr value, and have
651
+ # flipped/negated the guard if necessary. Now we know what to guard
652
+ # or defer to runtime assert on.
653
+ guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
654
+ self.guards.append(guard)
655
+ self.axioms.update(dict(self.get_implications(self.simplify(g))))
656
+ else:
657
+ self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
658
+
659
+ except Exception:
660
+ if fresh:
661
+ self._remove_fx_node(node)
662
+ raise
663
+
664
+ if not self._suppress_guards_tls():
665
+ if guard is not None: # we might have deferred this to runtime assert
666
+ for s in g.free_symbols:
667
+ self.symbol_guard_counter[s] += 1
668
+ # Forcing_spec to avoid infinite recursion
669
+ if (
670
+ not forcing_spec
671
+ and config.symbol_guard_limit_before_specialize is not None
672
+ and self.symbol_guard_counter[s]
673
+ > config.symbol_guard_limit_before_specialize
674
+ ):
675
+ # Force specialization
676
+ self.log.info(
677
+ "symbol_guard_limit_before_specialize=%s exceeded on %s",
678
+ config.symbol_guard_limit_before_specialize,
679
+ s,
680
+ )
681
+ self.evaluate_expr(s, forcing_spec=True)
682
+
683
+ return concrete_val
684
+
685
+
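# Editor's note (hedged sketch): patched_ShapeEnv only defines replacement
# methods; it is assumed that the package's patching code copies them onto
# torch's ShapeEnv while a patch context is active and restores the originals
# afterwards. A minimal version of that swap could look like this:
from torch.fx.experimental.symbolic_shapes import ShapeEnv

_shape_env_originals = {}

def _apply_shape_env_patches():
    for name in ("_check_frozen", "_set_replacement", "_log_guard", "_evaluate_expr"):
        _shape_env_originals[name] = getattr(ShapeEnv, name)
        setattr(ShapeEnv, name, getattr(patched_ShapeEnv, name))

def _undo_shape_env_patches():
    for name, original in _shape_env_originals.items():
        setattr(ShapeEnv, name, original)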
686
+ def patched_vmap(func, in_dims=0, out_dims=0, use_scan: bool = False):
687
+ """
688
+ Python implementation of :func:`torch.vmap`.
689
+ The implementation raises an issue when it is being exported with
690
+ :func:`torch.export.export` and the function is called with
691
+ non-tensor arguments and the batch size is dynamic.
692
+ """
693
+ from ...helpers import string_type
694
+
695
+ def wrapped(*args):
696
+ assert all(not isinstance(a, dict) for a in args), (
697
+ f"dictionaries are not implemented in "
698
+ f"args={string_type(args, with_shape=True)}"
699
+ )
700
+
701
+ in_dims_ = (
702
+ ([in_dims] * len(args))
703
+ if not isinstance(in_dims, (list, tuple))
704
+ else list(in_dims)
705
+ )
706
+ assert len(in_dims_) == len(args), (
707
+ f"Mismtch between in_dims={in_dims_} and "
708
+ f"args={string_type(args, with_shape=True)}"
709
+ )
710
+
711
+ batch_size = None
712
+ batched_args = []
713
+ for arg, in_dim in zip(args, in_dims_):
714
+ if in_dim is None:
715
+ batched_args.append(arg)
716
+ continue
717
+
718
+ assert batch_size is None or batch_size == arg.size(in_dim), (
719
+ f"Unable to continue, batch_size={batch_size}, in_dim={in_dim}, "
720
+ f"arg.size(in_dim)={arg.size(in_dim)}"
721
+ )
722
+ if batch_size is None:
723
+ batch_size = arg.size(in_dim)
724
+ arg = arg.movedim(in_dim, 0)
725
+ batched_args.append(arg)
726
+
727
+ if use_scan or (
728
+ all(isinstance(a, torch.Tensor) for a in args)
729
+ and isinstance(batch_size, torch.SymInt)
730
+ ):
731
+ batched_tensors = [
732
+ (
733
+ arg
734
+ if (isinstance(arg, torch.Tensor) and in_dim is not None)
735
+ else arg.unsqueeze(0).expand((batch_size, *arg.shape))
736
+ )
737
+ for arg, in_dim in zip(batched_args, in_dims_)
738
+ ]
739
+ results = torch.ops.higher_order.scan(
740
+ lambda *args, **kwargs: [func(*args, **kwargs)], [], batched_tensors, []
741
+ )
742
+ stacked = results[0]
743
+ if out_dims != 0:
744
+ return stacked.movedim(0, out_dims)
745
+ return stacked
746
+
747
+ else:
748
+ torch._check(
749
+ not isinstance(batch_size, torch.SymInt),
750
+ lambda: (
751
+ f"patched_vmap supports dynamic batch_size only if all arguments "
752
+ f"are tensors but types are {[type(a) for a in args]}"
753
+ ),
754
+ )
755
+ batched_tensors = [
756
+ (
757
+ (None, arg)
758
+ if (isinstance(arg, torch.Tensor) and in_dim is not None)
759
+ else (arg, arg)
760
+ )
761
+ for arg, in_dim in zip(batched_args, in_dims_)
762
+ ]
763
+
764
+ results = []
765
+ for i in range(batch_size):
766
+ input_slice = [v if v is not None else arg[i] for v, arg in batched_tensors]
767
+ result = func(*input_slice)
768
+ results.append(result)
769
+
770
+ if isinstance(results[0], torch.Tensor):
771
+ stacked = torch.stack(results)
772
+ if out_dims != 0:
773
+ return stacked.movedim(0, out_dims)
774
+ return stacked
775
+ return results
776
+
777
+ return wrapped
778
+
779
+
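# Editor's note (usage illustration): with concrete tensors and a static batch
# size the wrapper simply loops over dimension 0 and stacks the results,
# mimicking torch.vmap; with use_scan=True (or when all arguments are tensors
# and the batch size is symbolic) it routes through torch.ops.higher_order.scan,
# which is the form torch.export.export can handle.
batched_dot = patched_vmap(torch.dot)           # in_dims defaults to 0 for every argument
a, b = torch.randn(8, 5), torch.randn(8, 5)
out = batched_dot(a, b)                         # shape (8,), same as torch.vmap(torch.dot)(a, b)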
780
+ def patched__constrain_user_specified_dimhint_range(
781
+ symint: torch.SymInt,
782
+ hint: int,
783
+ dim: "_DimHint", # noqa: F821
784
+ range_constraints,
785
+ shape_env,
786
+ keypath: "KeyPath", # noqa: F821
787
+ i: Optional[int] = None,
788
+ ) -> Optional[str]:
789
+ """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
790
+ from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
791
+
792
+ trace_vr = (
793
+ range_constraints[symint.node.expr]
794
+ if not is_int(symint)
795
+ else ValueRanges(int(symint), int(symint))
796
+ )
797
+ # warn on 0/1 specialization for Dim.AUTO; not an actual error
798
+ # PATCHED: remove logging
799
+ # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
800
+ # pathstr = f"inputs{pytree.keystr(keypath)}"
801
+ # if i is not None:
802
+ # pathstr += f".shape[{i}]"
803
+ # msg = (
804
+ # f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
805
+ # f"with a sample input with hint = {hint}."
806
+ # )
807
+ # log.warning(msg)
808
+
809
+ try:
810
+ user_vr = ValueRanges(
811
+ lower=0 if dim.min is None else dim.min,
812
+ upper=int_oo if dim.max is None else dim.max,
813
+ )
814
+ if is_int(symint):
815
+ out_vr = trace_vr & user_vr
816
+ else:
817
+ range_constraints[symint.node.expr] &= user_vr
818
+ shape_env.var_to_range[symint.node._expr] &= user_vr
819
+ out_vr = range_constraints[symint.node.expr]
820
+
821
+ # check for Dim.DYNAMIC specializations; special case error message on 0/1
822
+ if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
823
+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
824
+ if i is not None:
825
+ path += f".shape[{i}]"
826
+ if (
827
+ trace_vr.is_singleton()
828
+ and hint in (0, 1)
829
+ # PATCHED: line removed
830
+ # and not torch.fx.experimental._config.backed_size_oblivious
831
+ ):
832
+ return None
833
+ # PATCHED: line removed
834
+ # msg = (
835
+ # f"- Received user-specified dim hint "
836
+ # f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
837
+ # f"but export 0/1 specialized due to hint of "
838
+ # f"{hint} for dimension {path}."
839
+ # )
840
+ else:
841
+ msg = (
842
+ f"- Received user-specified dim hint "
843
+ f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
844
+ f"but tracing inferred a static shape of "
845
+ f"{out_vr.lower} for dimension {path}."
846
+ )
847
+ return msg
848
+
849
+ except torch.utils._sympy.value_ranges.ValueRangeError:
850
+ path = f"inputs{torch.utils._pytree.keystr(keypath)}"
851
+ if i is not None:
852
+ path += f".shape[{i}]"
853
+ msg = (
854
+ f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
855
+ f"conflicting with the inferred min/max range of "
856
+ f"[{trace_vr.lower}, {trace_vr.upper}], "
857
+ f"for {path}."
858
+ )
859
+ return msg
860
+
861
+ return None
862
+
863
+
864
+ def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
865
+ """Patches ``torch._refs._maybe_broadcast``."""
866
+ from torch._prims_common import ShapeType, TensorLike, Number
867
+
868
+ # Computes common shape
869
+ common_shape = patched__broadcast_shapes(
870
+ *(t.shape if isinstance(t, TensorLike) else None for t in args)
871
+ )
872
+
873
+ def should_expand(a: ShapeType, b: ShapeType) -> bool:
874
+ from torch.fx.experimental.symbolic_shapes import (
875
+ guard_or_false,
876
+ sym_and,
877
+ sym_or,
878
+ )
879
+
880
+ if len(a) != len(b):
881
+ return True
882
+
883
+ for x, y in zip(a, b):
884
+ if guard_or_false(x != y):
885
+ # We know they are not the same.
886
+ return True
887
+
888
+ # They are the same or we do not know if they are the same or not.
889
+ # 1==1 no-broadcast
890
+ # u0==1 and 1==u0 cases. We broadcast!
891
+ if guard_or_false(sym_and(x == 1, y == 1)):
892
+ pass
893
+ elif guard_or_false(sym_or(x == 1, y == 1)):
894
+ # assume broadcasting.
895
+ return True
896
+
897
+ # u0==u1 assume the same, no broadcasting!
898
+ # PATCHED: avoid errors
899
+ return True # guard_or_true(x != y)
900
+ # torch._check(
901
+ # x == y,
902
+ # lambda x=x, y=y: (
903
+ # f"sizes assumed to be the same due to unbacked "
904
+ # f"broadcasting semantics x={x!r}, y={y!r}"
905
+ # ),
906
+ # )
907
+
908
+ return False
909
+
910
+ def __maybe_broadcast(x, shape):
911
+ if x is None:
912
+ return None
913
+ elif isinstance(x, Number):
914
+ return x
915
+ elif isinstance(x, TensorLike):
916
+ if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
917
+ return x
918
+
919
+ if should_expand(x.shape, common_shape):
920
+ return x.expand(common_shape)
921
+
922
+ return x
923
+ else:
924
+ raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
925
+
926
+ return tuple(__maybe_broadcast(x, common_shape) for x in args)
927
+
928
+
929
+ def patched__broadcast_in_dim_meta(
930
+ a: torch._prims_common.TensorLikeType,
931
+ shape: torch._prims_common.ShapeType,
932
+ broadcast_dimensions: Sequence[int],
933
+ ):
934
+ """Patches ``torch._prims._broadcast_in_dim_meta``."""
935
+ from torch.fx.experimental.symbolic_shapes import (
936
+ guard_or_false,
937
+ guard_or_true,
938
+ sym_or,
939
+ )
940
+
941
+ # Type checks
942
+ assert isinstance(a, torch._prims_common.TensorLike)
943
+ assert isinstance(shape, Sequence)
944
+ assert isinstance(broadcast_dimensions, Sequence)
945
+
946
+ # every dimension must be accounted for
947
+ assert a.ndim == len(broadcast_dimensions)
948
+
949
+ # broadcast shape must have weakly more dimensions
950
+ assert len(shape) >= a.ndim
951
+
952
+ # broadcast_dimensions must be an ascending sequence
953
+ # (no relative reordering of dims) of integers and
954
+ # each dimension must be within the new shape
955
+ def _greater_than_reduce(acc, x):
956
+ assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
957
+ assert x > acc
958
+ assert x < len(shape)
959
+
960
+ return x
961
+
962
+ reduce(_greater_than_reduce, broadcast_dimensions, -1)
963
+
964
+ # shape must be broadcastable to
965
+ for idx, new_idx in enumerate(broadcast_dimensions):
966
+ torch._check(
967
+ sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
968
+ lambda idx=idx, new_idx=new_idx: (
969
+ f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
970
+ ),
971
+ )
972
+
973
+ new_strides = []
974
+ original_idx = 0
975
+ for idx in range(len(shape)):
976
+ if idx in broadcast_dimensions:
977
+ # Assigns a stride of zero to dimensions
978
+ # which were actually broadcast
979
+ if guard_or_false(a.shape[original_idx] == 1):
980
+ if guard_or_false(a.shape[original_idx] == shape[idx]):
981
+ new_strides.append(a.stride()[original_idx])
982
+ else:
983
+ new_strides.append(0)
984
+ # PATCHED: disabled this check
985
+ elif guard_or_false(a.shape[original_idx] != 1):
986
+ new_strides.append(a.stride()[original_idx])
987
+ else:
988
+ # This check generates the following issue:
989
+ # non-broadcasting semantics require s3 == Max(s10, s3), False,
990
+ # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
991
+ # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
992
+ # original_idx=1
993
+ torch._check(
994
+ a.shape[original_idx] == shape[idx],
995
+ lambda idx=idx, original_idx=original_idx: (
996
+ f"non-broadcasting semantics require "
997
+ f"{a.shape[original_idx]} == {shape[idx]}, "
998
+ f"{guard_or_false(a.shape[idx] != 1)}, "
999
+ f"guard_or_false(a.shape[idx]==1)="
1000
+ f"{guard_or_false(a.shape[idx] == 1)}, "
1001
+ f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
1002
+ f"shape={shape}, original_idx={original_idx}"
1003
+ ),
1004
+ )
1005
+ new_strides.append(a.stride()[original_idx])
1006
+ original_idx = original_idx + 1
1007
+ else:
1008
+ if guard_or_true(shape[idx] != 1):
1009
+ # consistent with previous use of guard_size_oblivious
1010
+ new_strides.append(0)
1011
+ elif original_idx == a.ndim:
1012
+ new_strides.append(1)
1013
+ else:
1014
+ new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
1015
+
1016
+ return a.as_strided(shape, new_strides, a.storage_offset())
1017
+
1018
+
1019
+ def patched__broadcast_in_dim_meta_level_2(
1020
+ a: torch._prims_common.TensorLikeType,
1021
+ shape: torch._prims_common.ShapeType,
1022
+ broadcast_dimensions: Sequence[int],
1023
+ ):
1024
+ """Patches ``torch._prims._broadcast_in_dim_meta``."""
1025
+ from torch.fx.experimental.symbolic_shapes import (
1026
+ guard_or_false,
1027
+ guard_or_true,
1028
+ sym_or,
1029
+ )
1030
+
1031
+ # Type checks
1032
+ assert isinstance(a, torch._prims_common.TensorLike)
1033
+ assert isinstance(shape, Sequence)
1034
+ assert isinstance(broadcast_dimensions, Sequence)
1035
+
1036
+ # every dimension must be accounted for
1037
+ assert a.ndim == len(broadcast_dimensions)
1038
+
1039
+ # broadcast shape must have weakly more dimensions
1040
+ assert len(shape) >= a.ndim
1041
+
1042
+ # broadcast_dimensions must be an ascending sequence
1043
+ # (no relative reordering of dims) of integers and
1044
+ # each dimension must be within the new shape
1045
+ def _greater_than_reduce(acc, x):
1046
+ assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
1047
+ assert x > acc
1048
+ assert x < len(shape)
1049
+
1050
+ return x
1051
+
1052
+ reduce(_greater_than_reduce, broadcast_dimensions, -1)
1053
+
1054
+ # shape must be broadcastable to
1055
+ for idx, new_idx in enumerate(broadcast_dimensions):
1056
+ torch._check(
1057
+ sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
1058
+ lambda idx=idx, new_idx=new_idx: (
1059
+ f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
1060
+ ),
1061
+ )
1062
+
1063
+ new_strides = []
1064
+ original_idx = 0
1065
+ for idx in range(len(shape)):
1066
+ if idx in broadcast_dimensions:
1067
+ # Assigns a stride of zero to dimensions
1068
+ # which were actually broadcast
1069
+ if guard_or_false(a.shape[original_idx] == 1):
1070
+ if guard_or_false(a.shape[original_idx] == shape[idx]):
1071
+ new_strides.append(a.stride()[original_idx])
1072
+ else:
1073
+ new_strides.append(0)
1074
+ # PATCHED: disabled this check
1075
+ elif guard_or_false(a.shape[original_idx] != 1):
1076
+ new_strides.append(a.stride()[original_idx])
1077
+ else:
1078
+ # PATCHED: torch._check was removed
1079
+ new_strides.append(a.stride()[original_idx])
1080
+ original_idx = original_idx + 1
1081
+ else:
1082
+ if guard_or_true(shape[idx] != 1):
1083
+ # consistent with previous use of guard_size_oblivious
1084
+ new_strides.append(0)
1085
+ elif original_idx == a.ndim:
1086
+ new_strides.append(1)
1087
+ else:
1088
+ new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
1089
+
1090
+ return a.as_strided(shape, new_strides, a.storage_offset())
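# Editor's note (hedged sketch): these patches are not meant to be applied
# manually; the package exposes a context manager in
# onnx_diagnostic/torch_export_patches/__init__.py (listed above) that installs
# and removes them around an export call. The import name and keyword argument
# below are assumptions, not verified against this diff.
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed public name

def export_with_patches(model, args, dynamic_shapes=None):
    # Patches are active only inside the context manager; eager execution
    # outside of it is unaffected.
    with torch_export_patches(patch_transformers=True):  # keyword name is an assumption
        return torch.export.export(model, args, dynamic_shapes=dynamic_shapes)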