onnx_diagnostic-0.8.0-py3-none-any.whl
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1090 @@
import functools
import inspect
import operator
import os
import traceback
from functools import reduce
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def retrieve_stacktrace():
    """Retrieves and prints the current stack trace, avoids every torch file."""
    rows = []
    stack_frames = traceback.extract_stack()
    for frame in stack_frames:
        filename, lineno, function_name, code_line = frame
        if "/torch/" in filename:
            continue
        rows.append(f"File: {filename}, Line {lineno}, in {function_name}")
        if code_line:
            rows.append(f"    {code_line}")
    return "\n".join(rows)


def _catch_produce_guards_and_solve_constraints(
    previous_function: Callable,
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    equalities_inputs: "EqualityConstraint",  # noqa: F821
    original_signature: inspect.Signature,
    verbose: int = 0,
    **kwargs,
):
    try:
        return previous_function(
            fake_mode=fake_mode,
            gm=gm,
            dynamic_shapes=dynamic_shapes,
            equalities_inputs=equalities_inputs,
            original_signature=original_signature,
            **kwargs,
        )
    except Exception as e:
        if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
            raise
        if verbose:
            print(
                f"[_catch_produce_guards_and_solve_constraints] ERROR: "
                f"produce_guards_and_solve_constraints failed, "
                f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
                f"fake_mode={fake_mode}\n"
                f"dynamic_shapes={dynamic_shapes}\n"
                f"equalities_inputs={equalities_inputs}\n"
                f"original_signature={original_signature}\n"
                f"kwargs={kwargs}\n"
                f"exc={e}\ngm={gm}"
            )
        torch._dynamo.reset()


def patch__check_input_constraints_for_graph(
    previous_function: Callable,
    input_placeholders: list[torch.fx.Node],
    flat_args_with_path,
    range_constraints,
    verbose: int = 0,
) -> None:
    try:
        # PATCHED: catches exception and prints out the information instead of
        # stopping the conversion.
        return previous_function(input_placeholders, flat_args_with_path, range_constraints)
    except Exception as e:
        if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
            raise
        if verbose:
            print(
                f"[_check_input_constraints_for_graph] ERROR: "
                f"_check_input_constraints_for_graph failed, "
                f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
                f"input_placeholders={input_placeholders}\n"
                f"range_constraints={range_constraints}\n"
                f"exc={e}"
            )
        torch._dynamo.reset()


def patched_infer_size(a, b):
    """Patches ``torch._subclasses.fake_impls.infer_size``."""
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB. This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known. However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        try:
            b1 = guard_or_false(sizeA == 1)
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            b1 = False
        try:
            b2 = guard_or_false(sizeB == 1)
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            b2 = False
        try:
            b3 = guard_or_false(sizeA == sizeB)
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            b3 = False
        if b1 or b2 or b3:
            expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
        else:
            # PATCHED: generic case, the dimension is known, no need to assert
            expandedSizes[i] = torch.sym_max(sizeA, sizeB)
    return tuple(expandedSizes)


def patched__broadcast_shapes(*_shapes):
    """Patches ``torch._refs._broadcast_shapes``."""
    from functools import reduce
    from torch._prims_common import IntLike
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        is_nested_int,
    )

    shapes = tuple(
        (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
    )

    # Short-circuits on no input
    if len(shapes) == 0:
        return None

    for shape in shapes:
        if not isinstance(shape, Sequence):
            raise RuntimeError(
                "Input shapes should be of type ints, a tuple of ints, "
                "or a list of ints, got ",
                shape,
            )

    # Computes common shape
    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
    for _arg_idx, shape in enumerate(shapes):
        for idx in range(-1, -1 - len(shape), -1):
            if is_nested_int(shape[idx]):
                # Broadcasting is allowed for (j0, 1) or (j0, j0);
                # not (j0, j1), (j0, 5), etc.
                if is_nested_int(common_shape[idx]) and guard_or_false(
                    shape[idx] == common_shape[idx]
                ):
                    continue
            else:
                if guard_or_false(shape[idx] == common_shape[idx]):
                    continue
                # PATCHED: two cases, if == for sure, no broadcast,
                # otherwise maybe broadcast with max(dimensions)
                if guard_or_false(common_shape[idx] != 1):
                    pass
                elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
                    if shape[idx] < 0:
                        raise ValueError(
                            "Attempting to broadcast a dimension with negative length!"
                        )
                    common_shape[idx] = shape[idx]
                else:
                    common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])

    return common_shape


class patched_ShapeEnv:

    def _check_frozen(
        self, expr: "sympy.Basic", concrete_val: "sympy.Basic"  # noqa: F821
    ) -> None:
        if self.frozen:
            self.counter["ignored_backward_guard"] += 1
            # PATCHED: raised an exception instead of logging.
            raise AssertionError(
                f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
                f"this could result in accuracy problems"
            )

    def _set_replacement(
        self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str  # noqa: F821
    ) -> None:
        """
        Adds or updates a replacement for a symbol.
        Use this instead of `self.replacements[a] = tgt`.
        """
        if tgt == self.replacements.get(a, None):
            return

        if a in tgt.free_symbols:
            return

        import sympy
        from torch._logging import structured
        from torch.utils._traceback import CapturedTraceback
        from torch._logging import trace_structured
        from torch._guards import TracingContext
        from torch.utils._sympy.functions import FloorToInt, CeilToInt
        from torch.utils._sympy.solve import try_solve
        from torch.fx.experimental.symbolic_shapes import (
            _is_supported_equivalence,
            ValueRanges,
        )

        # Precondition: a == tgt
        assert isinstance(a, sympy.Symbol)

        if (
            getattr(self, "allow_complex_guards_as_runtime_asserts", False)
            or getattr(self, "prefer_deferred_runtime_asserts_over_guards", False)
        ) and not _is_supported_equivalence(tgt):
            # continuing leads to placeholder shapes
            # having complex expressions that we can't resolve
            return

        # Handles nested tensor symbolic variables which don't have
        # var_to_range bounds
        tgt_bound = None
        if a in self.var_to_range:
            src_bound = self.var_to_range[a]

            # First, refine the value range of a based on the computed value range
            # of tgt. This is always OK to do, even if we decide not to do the
            # substitution in the end. This might be a no-op, if a already has
            # a tighter bound
            tgt_bound = self.bound_sympy(tgt)
            self._update_var_to_range(a, tgt_bound)

            # Next, check if we can update the range of free symbols in tgt
            # based on the range in a. But only do it if:
            # - the source bound non-trivially improves over what we get out of
            #   the existing bounds.
            # - the replacement is univariate and we can invert the tgt expression
            if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1:
                b = next(iter(tgt.free_symbols))
                # Try to invert the equality
                r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
                if r is not None:
                    self.log.debug(
                        "set_replacement: solve for %s in %s == %s gives %s",
                        b,
                        a,
                        tgt,
                        r,
                    )
                    # The solution here can be non-integral, for example, if
                    # we have s0 = 2*s1, then s1 = s0/2. What we would like
                    # to do is calculated the bounds in arbitrary precision,
                    # and then requantize the bound to integers when we are
                    # done.
                    rat_b_bound = self.bound_sympy(r[1])
                    b_bound = ValueRanges(
                        CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)
                    )
                    self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a])
                    tgt_bound = self.bound_sympy(tgt)
                    assert tgt_bound.issubset(
                        src_bound
                    ), f"{tgt_bound=} not a subset of {src_bound=}"

            # TODO: Should we propagate size-like-ness?
            #
            # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
            # to become size-like.
            #
            # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
            # propagate in this case, because what if u0 == 0, then u1 is negative
            # and clearly isn't a size. So, at minimum, any f(x) whose value
            # range isn't [0, inf] given x in [0, inf] cannot propagate
            # size-like-ness. But there are many situations where you could
            # imagine u1 is going to be size-like and actually you just didn't
            # have a refined enough value range on u0. Since even innocuous
            # looking arithmetic operations can destroy size-like-ness, it's
            # best to not propagate it at all and force the user to annotate it
            # as necessary.
            #
            # Compromise: we preserve size-like-ness only for exact equality
            # and nothing else.
            if a in self.size_like and isinstance(tgt, sympy.Symbol):
                self.size_like.add(tgt)
            elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
                self.size_like.add(a)

            # Now, decide if we will do the substitution.
            #
            # - If the source has a non-trivial range, only substitute if
            #   we preserve this range. Note that we may have propagated
            #   the src_range to free variables in tgt when tgt is univariate
            #   and we could find an inverse, which helps us achieve this.
            #   This ensures we never "forget" about user defined ranges,
            #   even if they end up being defined on composite formulas
            #   like s0 + s1.
            #
            # - If the variable is unbacked, only substitute if the substitution
            #   would preserve the bounds also under size-like-ness conditions.

            if not tgt_bound.issubset(src_bound):
                self.log.debug(
                    "skipped set_replacement %s = %s (%s) [%s not subset of %s]",
                    a,
                    tgt,
                    msg,
                    tgt_bound,
                    src_bound,
                )
                return
            elif a in self.size_like:
                tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
                src_bound_so = self.bound_sympy(a, size_oblivious=True)
                if not tgt_bound_so.issubset(src_bound_so):
                    self.log.debug(
                        "skipped set_replacement %s = %s (%s) "
                        "[%s not subset of %s (size-oblivious conditions)]",
                        a,
                        tgt,
                        msg,
                        tgt_bound_so,
                        src_bound_so,
                    )
                    return

        if isinstance(tgt, (sympy.Integer, sympy.Float)):
            # specializing to a constant, which is likely unexpected (unless
            # you specified dynamic=True)

            user_tb = TracingContext.extract_stack()
            trace_structured(
                "symbolic_shape_specialization",
                metadata_fn=lambda: {
                    "symbol": repr(a),
                    "sources": [s.name() for s in self.var_to_sources.get(a, [])],
                    "value": repr(tgt),
                    "reason": msg,
                    "stack": structured.from_traceback(
                        CapturedTraceback.extract(skip=1).summary()
                    ),
                    "user_stack": (structured.from_traceback(user_tb) if user_tb else None),
                },
            )

            for source in self.var_to_sources.get(a, []):
                if user_tb:
                    self.specialization_stacks[source] = user_tb

            # PATCHED: removed lines
            # if config.print_specializations:
            #     self.log.warning(
            #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
            #     )
            #     self.log.debug("SPECIALIZATION", stack_info=True)
            # PATCHED: replaces logging by raising an exception
            assert msg != "range_refined_to_singleton", (
                f"patched_ShapeEnv: A dynamic dimension becomes static! "
                f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
            )
        # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
        self.replacements[a] = tgt
        # NB: the replacement may get refined, but the user will find the
        # FIRST one most useful (TODO: Maybe we could consider tracking all of
        # them)
        if a not in self.replacements_slocs:
            self.replacements_slocs[a] = self._get_sloc()
        self._update_version_counter()

        # When specializing 'a == tgt', the equality should be also conveyed to
        # Z3, in case an expression uses 'a'.
        self._add_target_expr(sympy.Eq(a, tgt, evaluate=False))

    def _log_guard(
        self, prefix: str, g: "SympyBoolean", forcing_spec: bool  # noqa: F821
    ) -> None:
        self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
        # PATCHED: removed
        # It happens too often to be relevant.
        # sloc, _maybe_extra_debug = self._get_stack_summary(True)
        # warnings.warn(
        #     f"A guard was added, prefix={prefix!r}, g={g!r}, "
        #     f"forcing_spec={forcing_spec}, location=\n{sloc}\n"
        #     f"--stack trace--\n{retrieve_stacktrace()}",
        #     RuntimeWarning,
        #     stacklevel=0,
        # )

    def _evaluate_expr(
        self,
        orig_expr: "sympy.Basic",  # noqa: F821
        hint: Optional[Union[bool, int, float]] = None,
        fx_node: Optional[torch.fx.Node] = None,
        size_oblivious: bool = False,
        fallback_value: Optional[bool] = None,
        *,
        forcing_spec: bool = False,
    ) -> "sympy.Basic":  # noqa: F821
        # TODO: split conjunctions and evaluate them separately
        import sympy
        from torch.fx.experimental import _config as config
        from torch.fx.experimental.symbolic_shapes import (
            SympyBoolean,
            log,
            SymT,
            symbol_is_type,
        )
        from torch._guards import ShapeGuard

        if isinstance(
            orig_expr,
            (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
        ):
            return orig_expr

        # Don't track this one. (Because this cache is inside this function the
        # cache only lasts for the invocation of this function call)
        @functools.cache
        def compute_concrete_val() -> sympy.Basic:
            if hint is None:
                # This is only ever called for expressions WITHOUT unbacked
                # symbols
                r = self.size_hint(orig_expr)
                assert r is not None
                return r
            else:
                return sympy.sympify(hint)

        concrete_val: Optional[sympy.Basic]

        # Check if:
        #   1. 'translation_validation' is set
        #   2. the corresponding 'fx_node' is not 'None'
        #   3. the guard should not be suppressed
        #   4. the guard doesn't contain backed symfloat symbols
        #      since z3 can't handle floats
        #   5. fallback_value is none.
        # If all of the above check, we create an FX node representing the
        # actual expression to be guarded.
        node = None
        fresh = False
        if (
            self._translation_validation_enabled
            and fx_node is not None
            and not self._suppress_guards_tls()
            and not size_oblivious
            and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
            and fallback_value is None
        ):
            # TODO: does this even worked with unbacked :think:
            concrete_val = compute_concrete_val()
            if concrete_val is sympy.true:
                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
            elif concrete_val is sympy.false:
                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
            else:
                eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
                node, fresh = self._create_fx_call_function(torch._assert, (eql,))

            assert node is not None
            # If this is a fresh node, we have to remember the event index that
            # corresponds to this assertion node.
            # Reason: so that, given an assertion node, we can replay the ShapeEnv
            # events until the point where this assertion node was freshly created.
            if fresh:
                self._add_fx_node_metadata(node)

        # After creating the FX node corresponding to orig_expr, we must make sure that
        # no error will be raised until the end of this function.
        #
        # Reason: the translation validation may become invalid otherwise.
        #
        # If an error is raised before the end of this function, we remove the FX node
        # inserted, and re-raise the error.
        guard = None

        try:
            if orig_expr.is_number:
                self.log.debug("eval %s [trivial]", orig_expr)
                if hint is not None:
                    if isinstance(hint, bool):
                        assert orig_expr == hint, f"{orig_expr} != {hint}"
                    else:
                        assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
                return orig_expr

            expr = orig_expr

            static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
            if static_expr is not None:
                self.log.debug(
                    "eval %s == %s [statically known]",
                    (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious),
                    static_expr,
                )
                if not size_oblivious and config.backed_size_oblivious and hint is not None:
                    # TODO: maybe reconcile this with use of counterfactual hints
                    # in unbacked case
                    assert static_expr == hint, f"{static_expr} != {hint}"
                return static_expr

            transmute_into_runtime_assert = False

            concrete_val = None
            if not (expr.free_symbols <= self.var_to_val.keys()):
                # TODO: dedupe this with _maybe_evaluate_static
                # Attempt to eliminate the unbacked SymInt
                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
                assert new_expr is not None
                if not (new_expr.free_symbols <= self.var_to_val.keys()):
                    ok = False

                    # fallback_value is set when guard_or_true or guard_or_false are used.
                    if not ok and fallback_value is not None:
                        self._log_suppressed_dde(orig_expr, fallback_value)
                        return fallback_value

                    # oblivious_var_to_val will be defined iff we have sizes
                    # with DimDynamic.OBLIVIOUS_SIZE type.
                    # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
                    if (
                        self.oblivious_var_to_val
                        and not (
                            correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
                        ).free_symbols
                        and not (
                            counterfactual_hint := orig_expr.xreplace(
                                {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
                            )
                        ).free_symbols
                        and correct_hint == counterfactual_hint
                    ):
                        # TODO: better logging
                        log.info(
                            "oblivious_size %s -> %s (passed counterfactual)",
                            orig_expr,
                            # pyrefly: ignore  # unbound-name
                            correct_hint,
                        )
                        # pyrefly: ignore  # unbound-name
                        concrete_val = correct_hint
                        # NB: do NOT transmute into runtime assert
                        ok = True

                    # unbacked_var_to_val is not None iff propagate_real_tensors is on.
                    # if propagate_real_tensors is on, we check the example values
                    # to generate (unsound_result)
                    # and if they pass we add a runtime assertions and continue.
                    if (
                        not ok
                        and self.unbacked_var_to_val
                        and not (
                            unsound_result := orig_expr.xreplace(
                                self.unbacked_var_to_val
                            ).xreplace(self.var_to_val)
                        ).free_symbols
                    ):
                        # pyrefly: ignore  # unbound-name
                        self._log_real_tensor_propagation(orig_expr, unsound_result)
                        transmute_into_runtime_assert = True
                        # pyrefly: ignore  # unbound-name
                        concrete_val = unsound_result
                        ok = True

                    # Check if this is coming from a python assert statement,
                    # if so, convert it to a runtime assertion
                    # instead of failing.
                    if not ok and self.trace_asserts and self._is_python_assert():
                        concrete_val = sympy.true
                        transmute_into_runtime_assert = True
                        ok = True

                    # PATCHED: ok -> True
                    ok = True
                    # if not ok:
                    #     raise self._make_data_dependent_error(
                    #         expr.xreplace(self.var_to_val),
                    #         expr,
                    #         expr_sym_node_id=self._expr_sym_node_id,
                    #     )
                else:
                    expr = new_expr

            if concrete_val is None:
                concrete_val = compute_concrete_val()
            self._check_frozen(expr, concrete_val)

            if (
                config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
                and isinstance(hint, bool)
                and isinstance(expr, (sympy.Eq, sympy.Ne))
            ):
                expr = sympy.Not(expr)

            # Turn this into a boolean expression, no longer need to consult
            # concrete_val
            if concrete_val is sympy.true:
                g = cast(SympyBoolean, expr)
            elif concrete_val is sympy.false:
                g = sympy.Not(expr)
            else:
                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]

            if transmute_into_runtime_assert:
                self.guard_or_defer_runtime_assert(
                    g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
                )
                return concrete_val

            if not self._suppress_guards_tls():
                self._log_guard("eval", g, forcing_spec=forcing_spec)

                # TODO: If we successfully eliminate a symbol via equality, it
                # is not actually necessary to save a guard for the equality,
                # as we will implicitly generate a guard when we match that
                # input against the symbol. Probably the easiest way to
                # implement this is to have maybe_guard_rel return a bool
                # saying if it "subsumed" the guard (and therefore the guard
                # is no longer necessary)
                self._maybe_guard_rel(g)

                if (
                    torch.compiler.is_exporting()
                    and self.prefer_deferred_runtime_asserts_over_guards
                ):
                    # it's fine to defer simple guards here without checking,
                    # the _maybe_guard_rel() call above will set replacements if possible,
                    # and so the result here will be statically known
                    self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
                else:
                    # at this point, we've evaluated the concrete expr value, and have
                    # flipped/negated the guard if necessary. Now we know what to guard
                    # or defer to runtime assert on.
                    guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
                    self.guards.append(guard)
                    self.axioms.update(dict(self.get_implications(self.simplify(g))))
            else:
                self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)

        except Exception:
            if fresh:
                self._remove_fx_node(node)
            raise

        if not self._suppress_guards_tls():
            if guard is not None:  # we might have deferred this to runtime assert
                for s in g.free_symbols:
                    self.symbol_guard_counter[s] += 1
                    # Forcing_spec to avoid infinite recursion
                    if (
                        not forcing_spec
                        and config.symbol_guard_limit_before_specialize is not None
                        and self.symbol_guard_counter[s]
                        > config.symbol_guard_limit_before_specialize
                    ):
                        # Force specialization
                        self.log.info(
                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
                            config.symbol_guard_limit_before_specialize,
                            s,
                        )
                        self.evaluate_expr(s, forcing_spec=True)

        return concrete_val


def patched_vmap(func, in_dims=0, out_dims=0, use_scan: bool = False):
    """
    Python implementation of :func:`torch.vmap`.
    The implementation raises an issue when it is being exported with
    :func:`torch.export.export` when the function is called with
    non tensors arguments and the batch size is dynamic.
    """
    from ...helpers import string_type

    def wrapped(*args):
        assert all(not isinstance(a, dict) for a in args), (
            f"dictionaries are not implemented in "
            f"args={string_type(args, with_shape=True)}"
        )

        in_dims_ = (
            ([in_dims] * len(args))
            if not isinstance(in_dims, (list, tuple))
            else list(in_dims)
        )
        assert len(in_dims_) == len(args), (
            f"Mismtch between in_dims={in_dims_} and "
            f"args={string_type(args, with_shape=True)}"
        )

        batch_size = None
        batched_args = []
        for arg, in_dim in zip(args, in_dims_):
            if in_dim is None:
                batched_args.append(arg)
                continue

            assert batch_size is None or batch_size == arg.size(in_dim), (
                f"Unable to continue, batch_size={batch_size}, in_dim={in_dim}, "
                f"arg.size(in_dim)={arg.size(in_dim)}"
            )
            if batch_size is None:
                batch_size = arg.size(in_dim)
            arg = arg.movedim(in_dim, 0)
            batched_args.append(arg)

        if use_scan or (
            all(isinstance(a, torch.Tensor) for a in args)
            and isinstance(batch_size, torch.SymInt)
        ):
            batched_tensors = [
                (
                    arg
                    if (isinstance(arg, torch.Tensor) and in_dim is not None)
                    else arg.unsqueeze(0).expand((batch_size, *arg.shape))
                )
                for arg, in_dim in zip(batched_args, in_dims_)
            ]
            results = torch.ops.higher_order.scan(
                lambda *args, **kwargs: [func(*args, **kwargs)], [], batched_tensors, []
            )
            stacked = results[0]
            if out_dims != 0:
                return stacked.movedim(0, out_dims)
            return stacked

        else:
            torch._check(
                not isinstance(batch_size, torch.SymInt),
                lambda: (
                    f"patched_vmap supports dynamic batch_size only if all arguments "
                    f"are tensors but types are {[type(a) for a in args]}"
                ),
            )
            batched_tensors = [
                (
                    (None, arg)
                    if (isinstance(arg, torch.Tensor) and in_dim is not None)
                    else (arg, arg)
                )
                for arg, in_dim in zip(batched_args, in_dims_)
            ]

            results = []
            for i in range(batch_size):
                input_slice = [v if v is not None else arg[i] for v, arg in batched_tensors]
                result = func(*input_slice)
                results.append(result)

            if isinstance(results[0], torch.Tensor):
                stacked = torch.stack(results)
                if out_dims != 0:
                    return stacked.movedim(0, out_dims)
                return stacked
            return results

    return wrapped


def patched__constrain_user_specified_dimhint_range(
    symint: torch.SymInt,
    hint: int,
    dim: "_DimHint",  # noqa: F821
    range_constraints,
    shape_env,
    keypath: "KeyPath",  # noqa: F821
    i: Optional[int] = None,
) -> Optional[str]:
    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges

    trace_vr = (
        range_constraints[symint.node.expr]
        if not is_int(symint)
        else ValueRanges(int(symint), int(symint))
    )
    # warn on 0/1 specialization for Dim.AUTO; not an actual error
    # PATCHED: remove logging
    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
    #     pathstr = f"inputs{pytree.keystr(keypath)}"
    #     if i is not None:
    #         pathstr += f".shape[{i}]"
    #     msg = (
    #         f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
    #         f"with a sample input with hint = {hint}."
    #     )
    #     log.warning(msg)

    try:
        user_vr = ValueRanges(
            lower=0 if dim.min is None else dim.min,
            upper=int_oo if dim.max is None else dim.max,
        )
        if is_int(symint):
            out_vr = trace_vr & user_vr
        else:
            range_constraints[symint.node.expr] &= user_vr
            shape_env.var_to_range[symint.node._expr] &= user_vr
            out_vr = range_constraints[symint.node.expr]

        # check for Dim.DYNAMIC specializations; special case error message on 0/1
        if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
            if i is not None:
                path += f".shape[{i}]"
            if (
                trace_vr.is_singleton()
                and hint in (0, 1)
                # PATCHED: line removed
                # and not torch.fx.experimental._config.backed_size_oblivious
            ):
                return None
                # PATCHED: line removed
                # msg = (
                #     f"- Received user-specified dim hint "
                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
                #     f"but export 0/1 specialized due to hint of "
                #     f"{hint} for dimension {path}."
                # )
            else:
                msg = (
                    f"- Received user-specified dim hint "
                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
                    f"but tracing inferred a static shape of "
                    f"{out_vr.lower} for dimension {path}."
                )
            return msg

    except torch.utils._sympy.value_ranges.ValueRangeError:
        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
        if i is not None:
            path += f".shape[{i}]"
        msg = (
            f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
            f"conflicting with the inferred min/max range of "
            f"[{trace_vr.lower}, {trace_vr.upper}], "
            f"for {path}."
        )
        return msg

    return None


def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    """Patches ``torch._refs._maybe_broadcast``."""
    from torch._prims_common import ShapeType, TensorLike, Number

    # Computes common shape
    common_shape = patched__broadcast_shapes(
        *(t.shape if isinstance(t, TensorLike) else None for t in args)
    )

    def should_expand(a: ShapeType, b: ShapeType) -> bool:
        from torch.fx.experimental.symbolic_shapes import (
            guard_or_false,
            sym_and,
            sym_or,
        )

        if len(a) != len(b):
            return True

        for x, y in zip(a, b):
            if guard_or_false(x != y):
                # We know they are not the same.
                return True

            # They are the same or we do not know if they are the same or not.
            # 1==1 no-broadcast
            # u0==1 and 1==u0 cases. We broadcast!
            if guard_or_false(sym_and(x == 1, y == 1)):
                pass
            elif guard_or_false(sym_or(x == 1, y == 1)):
                # assume broadcasting.
                return True

            # u0==u1 assume the same, no broadcasting!
            # PATCHED: avoid errors
            return True  # guard_or_true(x != y)
            # torch._check(
            #     x == y,
            #     lambda x=x, y=y: (
            #         f"sizes assumed to be the same due to unbacked "
            #         f"broadcasting semantics x={x!r}, y={y!r}"
            #     ),
            # )

        return False

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
                return x

            if should_expand(x.shape, common_shape):
                return x.expand(common_shape)

            return x
        else:
            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")

    return tuple(__maybe_broadcast(x, common_shape) for x in args)


def patched__broadcast_in_dim_meta(
    a: torch._prims_common.TensorLikeType,
    shape: torch._prims_common.ShapeType,
    broadcast_dimensions: Sequence[int],
):
    """Patches ``torch._prims._broadcast_in_dim_meta``."""
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        sym_or,
    )

    # Type checks
    assert isinstance(a, torch._prims_common.TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
        assert x > acc
        assert x < len(shape)

        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):
        torch._check(
            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
            lambda idx=idx, new_idx=new_idx: (
                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
            ),
        )

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if guard_or_false(a.shape[original_idx] == 1):
                if guard_or_false(a.shape[original_idx] == shape[idx]):
                    new_strides.append(a.stride()[original_idx])
                else:
                    new_strides.append(0)
            # PATCHED: disabled this check
            elif guard_or_false(a.shape[original_idx] != 1):
                new_strides.append(a.stride()[original_idx])
            else:
                # This checks generates the following issue:
                # non-broadcasting semantics require s3 == Max(s10, s3), False,
                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
                # original_idx=1
                torch._check(
                    a.shape[original_idx] == shape[idx],
                    lambda idx=idx, original_idx=original_idx: (
                        f"non-broadcasting semantics require "
                        f"{a.shape[original_idx]} == {shape[idx]}, "
                        f"{guard_or_false(a.shape[idx] != 1)}, "
                        f"guard_or_false(a.shape[idx]==1)="
                        f"{guard_or_false(a.shape[idx] == 1)}, "
                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
                        f"shape={shape}, original_idx={original_idx}"
                    ),
                )
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            if guard_or_true(shape[idx] != 1):
                # consistent with previous use of guard_size_oblivious
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())


def patched__broadcast_in_dim_meta_level_2(
    a: torch._prims_common.TensorLikeType,
    shape: torch._prims_common.ShapeType,
    broadcast_dimensions: Sequence[int],
):
    """Patches ``torch._prims._broadcast_in_dim_meta``."""
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        sym_or,
    )

    # Type checks
    assert isinstance(a, torch._prims_common.TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
        assert x > acc
        assert x < len(shape)

        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):
        torch._check(
            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
            lambda idx=idx, new_idx=new_idx: (
                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
            ),
        )

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if guard_or_false(a.shape[original_idx] == 1):
                if guard_or_false(a.shape[original_idx] == shape[idx]):
                    new_strides.append(a.stride()[original_idx])
                else:
                    new_strides.append(0)
            # PATCHED: disabled this check
            elif guard_or_false(a.shape[original_idx] != 1):
                new_strides.append(a.stride()[original_idx])
            else:
                # PATCHED: torch._check was removed
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            if guard_or_true(shape[idx] != 1):
                # consistent with previous use of guard_size_oblivious
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())
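
Note: the module above only defines replacement callables; it does not install them itself. As a rough illustration (not part of the package), the sketch below shows how one of the standalone patches could be applied temporarily around a torch.export call. The target attribute (torch._refs._broadcast_shapes) is the one named in the patch's own docstring; the helper context manager and the model/args/dynamic_shapes names are hypothetical placeholders.

import contextlib

import torch
import torch._refs


@contextlib.contextmanager
def _temporarily_patch(owner, name, replacement):
    # Swap owner.<name> for `replacement` inside the `with` block, then restore it.
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)


# Hypothetical usage: export with the patched broadcast-shape rule in place.
# with _temporarily_patch(torch._refs, "_broadcast_shapes", patched__broadcast_shapes):
#     ep = torch.export.export(model, args, dynamic_shapes=dynamic_shapes)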