onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +7 -2
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/log_helper.py +53 -17
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +2 -1
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +30 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +184 -151
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +20 -5
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +52 -20
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +540 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -5
- onnx_diagnostic/torch_models/validate.py +116 -50
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import inspect
|
|
3
|
+
import operator
|
|
2
4
|
import os
|
|
3
5
|
import traceback
|
|
4
|
-
from
|
|
6
|
+
from functools import reduce
|
|
7
|
+
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
|
|
5
8
|
import torch
|
|
6
9
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
7
10
|
|
|
@@ -65,6 +68,8 @@ def patch__check_input_constraints_for_graph(
|
|
|
65
68
|
verbose: int = 0,
|
|
66
69
|
) -> None:
|
|
67
70
|
try:
|
|
71
|
+
# PATCHED: catches exception and prints out the information instead of
|
|
72
|
+
# stopping the conversion.
|
|
68
73
|
return previous_function(input_placeholders, flat_args_with_path, range_constraints)
|
|
69
74
|
except Exception as e:
|
|
70
75
|
if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
|
|
@@ -122,8 +127,7 @@ def patched_infer_size(a, b):
|
|
|
122
127
|
if b1 or b2 or b3:
|
|
123
128
|
expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
|
|
124
129
|
else:
|
|
125
|
-
#
|
|
126
|
-
# Try model SmolLM.
|
|
130
|
+
# PATCHED: generic case, the dimension is known, no need to assert
|
|
127
131
|
expandedSizes[i] = torch.sym_max(sizeA, sizeB)
|
|
128
132
|
return tuple(expandedSizes)
|
|
129
133
|
|
|
@@ -132,7 +136,11 @@ def patched__broadcast_shapes(*_shapes):
|
|
|
132
136
|
"""Patches ``torch._refs._broadcast_shapes``."""
|
|
133
137
|
from functools import reduce
|
|
134
138
|
from torch._prims_common import IntLike
|
|
135
|
-
from torch.fx.experimental.symbolic_shapes import
|
|
139
|
+
from torch.fx.experimental.symbolic_shapes import (
|
|
140
|
+
guard_size_oblivious,
|
|
141
|
+
guard_or_false,
|
|
142
|
+
is_nested_int,
|
|
143
|
+
)
|
|
136
144
|
|
|
137
145
|
shapes = tuple(
|
|
138
146
|
(x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
|
|
@@ -142,17 +150,30 @@ def patched__broadcast_shapes(*_shapes):
|
|
|
142
150
|
if len(shapes) == 0:
|
|
143
151
|
return None
|
|
144
152
|
|
|
145
|
-
# Type checking
|
|
146
|
-
# TODO: make common validations available as utils
|
|
147
153
|
for shape in shapes:
|
|
148
|
-
|
|
154
|
+
if not isinstance(shape, Sequence):
|
|
155
|
+
raise RuntimeError(
|
|
156
|
+
"Input shapes should be of type ints, a tuple of ints, "
|
|
157
|
+
"or a list of ints, got ",
|
|
158
|
+
shape,
|
|
159
|
+
)
|
|
149
160
|
|
|
150
161
|
# Computes common shape
|
|
151
|
-
common_shape = [
|
|
152
|
-
1,
|
|
153
|
-
] * reduce(max, (len(shape) for shape in shapes))
|
|
162
|
+
common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
|
|
154
163
|
for _arg_idx, shape in enumerate(shapes):
|
|
155
164
|
for idx in range(-1, -1 - len(shape), -1):
|
|
165
|
+
if is_nested_int(shape[idx]):
|
|
166
|
+
# Broadcasting is allowed for (j0, 1) or (j0, j0);
|
|
167
|
+
# not (j0, j1), (j0, 5), etc.
|
|
168
|
+
if is_nested_int(common_shape[idx]) and guard_or_false(
|
|
169
|
+
shape[idx] == common_shape[idx]
|
|
170
|
+
):
|
|
171
|
+
continue
|
|
172
|
+
else:
|
|
173
|
+
if guard_or_false(shape[idx] == common_shape[idx]):
|
|
174
|
+
continue
|
|
175
|
+
# PATCHED: two cases, if == for sure, no broadcast,
|
|
176
|
+
# otherwise maybe broadcast with max(dimensions)
|
|
156
177
|
if guard_size_oblivious(common_shape[idx] == 1):
|
|
157
178
|
if shape[idx] < 0:
|
|
158
179
|
raise ValueError(
|
|
@@ -172,6 +193,7 @@ class patched_ShapeEnv:
|
|
|
172
193
|
) -> None:
|
|
173
194
|
if self.frozen:
|
|
174
195
|
self.counter["ignored_backward_guard"] += 1
|
|
196
|
+
# PATCHED: raised an exception instead of logging.
|
|
175
197
|
raise AssertionError(
|
|
176
198
|
f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
|
|
177
199
|
f"this could result in accuracy problems"
|
|
@@ -338,11 +360,13 @@ class patched_ShapeEnv:
|
|
|
338
360
|
},
|
|
339
361
|
)
|
|
340
362
|
|
|
363
|
+
# PATCHED: removed lines
|
|
341
364
|
# if config.print_specializations:
|
|
342
365
|
# self.log.warning(
|
|
343
366
|
# "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
|
|
344
367
|
# )
|
|
345
368
|
# self.log.debug("SPECIALIZATION", stack_info=True)
|
|
369
|
+
# PATCHED: replaces logging by raising an exception
|
|
346
370
|
assert msg != "range_refined_to_singleton", (
|
|
347
371
|
f"patched_ShapeEnv: A dynamic dimension becomes static! "
|
|
348
372
|
f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
|
|
@@ -364,6 +388,7 @@ class patched_ShapeEnv:
|
|
|
364
388
|
self, prefix: str, g: "SympyBoolean", forcing_spec: bool # noqa: F821
|
|
365
389
|
) -> None:
|
|
366
390
|
self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
|
|
391
|
+
# PATCHED: removed
|
|
367
392
|
# It happens too often to be relevant.
|
|
368
393
|
# sloc, _maybe_extra_debug = self._get_stack_summary(True)
|
|
369
394
|
# warnings.warn(
|
|
@@ -374,6 +399,284 @@ class patched_ShapeEnv:
|
|
|
374
399
|
# stacklevel=0,
|
|
375
400
|
# )
|
|
376
401
|
|
|
402
|
+
def _evaluate_expr(
    self,
    orig_expr: "sympy.Basic",  # noqa: F821
    hint: Optional[Union[bool, int, float]] = None,
    fx_node: Optional[torch.fx.Node] = None,
    size_oblivious: bool = False,
    fallback_value: Optional[bool] = None,
    *,
    forcing_spec: bool = False,
) -> "sympy.Basic":  # noqa: F821
    """Patched version of ``ShapeEnv._evaluate_expr``.

    Evaluates *orig_expr* to a concrete value and records the matching guard
    (or defers it to a runtime assert).

    PATCHED: when the expression still depends on symbols missing from
    ``self.var_to_val`` after simplification, and none of the fallbacks
    (``fallback_value``, oblivious hints, real-tensor propagation, python
    assert tracing) applies, no data-dependent error is raised here.
    Evaluation continues with ``compute_concrete_val()`` instead.

    Args:
        orig_expr: sympy expression to evaluate.
        hint: concrete value used instead of ``self.size_hint(orig_expr)``
            when provided.
        fx_node: FX node used to emit translation-validation assertions.
        size_oblivious: forwarded to ``_maybe_evaluate_static`` and stored on
            the recorded ``ShapeGuard``.
        fallback_value: value returned when the expression cannot be decided.
            It is set when ``guard_or_true`` / ``guard_or_false`` are used.
        forcing_spec: set when this call forces a specialization. It avoids
            infinite recursion through ``symbol_guard_limit_before_specialize``.

    Returns:
        The concrete value of the expression, or *fallback_value*.
    """
    # TODO: split conjunctions and evaluate them separately
    import sympy
    from torch.fx.experimental import _config as config
    from torch.fx.experimental.symbolic_shapes import (
        SympyBoolean,
        log,
        SymT,
        symbol_is_type,
    )
    from torch._guards import ShapeGuard

    # Boolean constants evaluate to themselves.
    if isinstance(
        orig_expr,
        (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
    ):
        return orig_expr

    # Don't track this one. (Because this cache is inside this function the
    # cache only lasts for the invocation of this function call)
    @functools.cache
    def compute_concrete_val() -> sympy.Basic:
        if hint is None:
            # This is only ever called for expressions WITHOUT unbacked
            # symbols
            r = self.size_hint(orig_expr)
            assert r is not None
            return r
        else:
            return sympy.sympify(hint)

    concrete_val: Optional[sympy.Basic]

    # Check if:
    #   1. 'translation_validation' is set
    #   2. the corresponding 'fx_node' is not 'None'
    #   3. the guard should not be suppressed
    #   4. the guard doesn't contain backed symfloat symbols
    #      since z3 can't handle floats
    #   5. fallback_value is none.
    # If all of the above check, we create an FX node representing the
    # actual expression to be guarded.
    node = None
    fresh = False
    if (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
        and not size_oblivious
        and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
        and fallback_value is None
    ):
        # TODO: does this even worked with unbacked :think:
        concrete_val = compute_concrete_val()
        if concrete_val is sympy.true:
            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
        elif concrete_val is sympy.false:
            neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
            node, fresh = self._create_fx_call_function(torch._assert, (neg,))
        else:
            eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
            node, fresh = self._create_fx_call_function(torch._assert, (eql,))

        assert node is not None
        # If this is a fresh node, we have to remember the event index that
        # corresponds to this assertion node.
        # Reason: so that, given an assertion node, we can replay the ShapeEnv
        # events until the point where this assertion node was freshly created.
        if fresh:
            self._add_fx_node_metadata(node)

    # After creating the FX node corresponding to orig_expr, we must make sure that
    # no error will be raised until the end of this function.
    #
    # Reason: the translation validation may become invalid otherwise.
    #
    # If an error is raised before the end of this function, we remove the FX node
    # inserted, and re-raise the error.
    guard = None

    try:
        if orig_expr.is_number:
            self.log.debug("eval %s [trivial]", orig_expr)
            if hint is not None:
                if isinstance(hint, bool):
                    assert orig_expr == hint, f"{orig_expr} != {hint}"
                else:
                    assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
            return orig_expr

        expr = orig_expr

        static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
        if static_expr is not None:
            self.log.debug(
                "eval %s == %s [statically known]",
                # FIX: log the expression itself; the previous code logged the
                # value of ``size_oblivious`` (always False in this branch).
                (f"size_oblivious({orig_expr})" if size_oblivious else orig_expr),
                static_expr,
            )
            if not size_oblivious and config.backed_size_oblivious and hint is not None:
                # TODO: maybe reconcile this with use of counterfactual hints
                # in unbacked case
                assert static_expr == hint, f"{static_expr} != {hint}"
            return static_expr

        transmute_into_runtime_assert = False

        concrete_val = None
        if not (expr.free_symbols <= self.var_to_val.keys()):
            # TODO: dedupe this with _maybe_evaluate_static
            # Attempt to eliminate the unbacked SymInt
            new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
            assert new_expr is not None
            if not (new_expr.free_symbols <= self.var_to_val.keys()):
                ok = False

                # fallback_value is set when guard_or_true or guard_or_false are used.
                if not ok and fallback_value is not None:
                    self._log_suppressed_dde(orig_expr, fallback_value)
                    return fallback_value

                # oblivious_var_to_val will be defined iff we have sizes
                # with DimDynamic.OBLIVIOUS_SIZE type.
                # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
                if (
                    self.oblivious_var_to_val
                    and not (
                        correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
                    ).free_symbols
                    and not (
                        counterfactual_hint := orig_expr.xreplace(
                            {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
                        )
                    ).free_symbols
                    and correct_hint == counterfactual_hint
                ):
                    # TODO: better logging
                    log.info(
                        "oblivious_size %s -> %s (passed counterfactual)",
                        orig_expr,
                        # pyrefly: ignore # unbound-name
                        correct_hint,
                    )
                    # pyrefly: ignore # unbound-name
                    concrete_val = correct_hint
                    # NB: do NOT transmute into runtime assert
                    ok = True

                # unbacked_var_to_val is not None iff propagate_real_tensors is on.
                # if propagate_real_tensors is on, we check the example values
                # to generate (unsound_result)
                # and if they pass we add a runtime assertions and continue.
                if (
                    not ok
                    and self.unbacked_var_to_val
                    and not (
                        unsound_result := orig_expr.xreplace(
                            self.unbacked_var_to_val
                        ).xreplace(self.var_to_val)
                    ).free_symbols
                ):
                    # pyrefly: ignore # unbound-name
                    self._log_real_tensor_propagation(orig_expr, unsound_result)
                    transmute_into_runtime_assert = True
                    # pyrefly: ignore # unbound-name
                    concrete_val = unsound_result
                    ok = True

                # Check if this is coming from a python assert statement,
                # if so, convert it to a runtime assertion
                # instead of failing.
                if not ok and self.trace_asserts and self._is_python_assert():
                    concrete_val = sympy.true
                    transmute_into_runtime_assert = True
                    ok = True

                # PATCHED: ok -> True, the data-dependent error below is never
                # raised; concrete_val is computed further down instead.
                ok = True
                # if not ok:
                #     raise self._make_data_dependent_error(
                #         expr.xreplace(self.var_to_val),
                #         expr,
                #         expr_sym_node_id=self._expr_sym_node_id,
                #     )
            else:
                expr = new_expr

        if concrete_val is None:
            concrete_val = compute_concrete_val()
        self._check_frozen(expr, concrete_val)

        if (
            config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
            and isinstance(hint, bool)
            and isinstance(expr, (sympy.Eq, sympy.Ne))
        ):
            expr = sympy.Not(expr)

        # Turn this into a boolean expression, no longer need to consult
        # concrete_val
        if concrete_val is sympy.true:
            g = cast(SympyBoolean, expr)
        elif concrete_val is sympy.false:
            g = sympy.Not(expr)
        else:
            g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]

        if transmute_into_runtime_assert:
            self.guard_or_defer_runtime_assert(
                g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
            )
            return concrete_val

        if not self._suppress_guards_tls():
            self._log_guard("eval", g, forcing_spec=forcing_spec)

            # TODO: If we successfully eliminate a symbol via equality, it
            # is not actually necessary to save a guard for the equality,
            # as we will implicitly generate a guard when we match that
            # input against the symbol.  Probably the easiest way to
            # implement this is to have maybe_guard_rel return a bool
            # saying if it "subsumed" the guard (and therefore the guard
            # is no longer necessary)
            self._maybe_guard_rel(g)

            if (
                torch.compiler.is_exporting()
                and self.prefer_deferred_runtime_asserts_over_guards
            ):
                # it's fine to defer simple guards here without checking,
                # the _maybe_guard_rel() call above will set replacements if possible,
                # and so the result here will be statically known
                self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
            else:
                # at this point, we've evaluated the concrete expr value, and have
                # flipped/negated the guard if necessary. Now we know what to guard
                # or defer to runtime assert on.
                guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
                self.guards.append(guard)
                self.axioms.update(dict(self.get_implications(self.simplify(g))))
        else:
            self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)

    except Exception:
        # Undo the translation-validation node created above before re-raising.
        if fresh:
            self._remove_fx_node(node)
        raise

    if not self._suppress_guards_tls():
        if guard is not None:  # we might have deferred this to runtime assert
            for s in g.free_symbols:
                self.symbol_guard_counter[s] += 1
                # Forcing_spec to avoid infinite recursion
                if (
                    not forcing_spec
                    and config.symbol_guard_limit_before_specialize is not None
                    and self.symbol_guard_counter[s]
                    > config.symbol_guard_limit_before_specialize
                ):
                    # Force specialization
                    self.log.info(
                        "symbol_guard_limit_before_specialize=%s exceeded on %s",
                        config.symbol_guard_limit_before_specialize,
                        s,
                    )
                    self.evaluate_expr(s, forcing_spec=True)

    return concrete_val
|
|
679
|
+
|
|
377
680
|
|
|
378
681
|
def patched_vmap(func, in_dims=0, out_dims=0):
|
|
379
682
|
"""
|
|
@@ -464,3 +767,230 @@ def patched_vmap(func, in_dims=0, out_dims=0):
|
|
|
464
767
|
return results
|
|
465
768
|
|
|
466
769
|
return wrapped
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
def patched__constrain_user_specified_dimhint_range(
|
|
773
|
+
symint: torch.SymInt,
|
|
774
|
+
hint: int,
|
|
775
|
+
dim: "_DimHint", # noqa: F821
|
|
776
|
+
range_constraints,
|
|
777
|
+
shape_env,
|
|
778
|
+
keypath: "KeyPath", # noqa: F821
|
|
779
|
+
i: Optional[int] = None,
|
|
780
|
+
) -> Optional[str]:
|
|
781
|
+
"""Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
|
|
782
|
+
from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
|
|
783
|
+
|
|
784
|
+
trace_vr = (
|
|
785
|
+
range_constraints[symint.node.expr]
|
|
786
|
+
if not is_int(symint)
|
|
787
|
+
else ValueRanges(int(symint), int(symint))
|
|
788
|
+
)
|
|
789
|
+
# warn on 0/1 specialization for Dim.AUTO; not an actual error
|
|
790
|
+
# PATCHED: remove logging
|
|
791
|
+
# if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
|
|
792
|
+
# pathstr = f"inputs{pytree.keystr(keypath)}"
|
|
793
|
+
# if i is not None:
|
|
794
|
+
# pathstr += f".shape[{i}]"
|
|
795
|
+
# msg = (
|
|
796
|
+
# f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
|
|
797
|
+
# f"with a sample input with hint = {hint}."
|
|
798
|
+
# )
|
|
799
|
+
# log.warning(msg)
|
|
800
|
+
|
|
801
|
+
try:
|
|
802
|
+
user_vr = ValueRanges(
|
|
803
|
+
lower=0 if dim.min is None else dim.min,
|
|
804
|
+
upper=int_oo if dim.max is None else dim.max,
|
|
805
|
+
)
|
|
806
|
+
if is_int(symint):
|
|
807
|
+
out_vr = trace_vr & user_vr
|
|
808
|
+
else:
|
|
809
|
+
range_constraints[symint.node.expr] &= user_vr
|
|
810
|
+
shape_env.var_to_range[symint.node._expr] &= user_vr
|
|
811
|
+
out_vr = range_constraints[symint.node.expr]
|
|
812
|
+
|
|
813
|
+
# check for Dim.DYNAMIC specializations; special case error message on 0/1
|
|
814
|
+
if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
|
|
815
|
+
path = f"inputs{torch.utils._pytree.keystr(keypath)}"
|
|
816
|
+
if i is not None:
|
|
817
|
+
path += f".shape[{i}]"
|
|
818
|
+
if (
|
|
819
|
+
trace_vr.is_singleton()
|
|
820
|
+
and hint in (0, 1)
|
|
821
|
+
# PATCHED: line removed
|
|
822
|
+
# and not torch.fx.experimental._config.backed_size_oblivious
|
|
823
|
+
):
|
|
824
|
+
return None
|
|
825
|
+
# PATCHED: line removed
|
|
826
|
+
# msg = (
|
|
827
|
+
# f"- Received user-specified dim hint "
|
|
828
|
+
# f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
|
|
829
|
+
# f"but export 0/1 specialized due to hint of "
|
|
830
|
+
# f"{hint} for dimension {path}."
|
|
831
|
+
# )
|
|
832
|
+
else:
|
|
833
|
+
msg = (
|
|
834
|
+
f"- Received user-specified dim hint "
|
|
835
|
+
f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
|
|
836
|
+
f"but tracing inferred a static shape of "
|
|
837
|
+
f"{out_vr.lower} for dimension {path}."
|
|
838
|
+
)
|
|
839
|
+
return msg
|
|
840
|
+
|
|
841
|
+
except torch.utils._sympy.value_ranges.ValueRangeError:
|
|
842
|
+
path = f"inputs{torch.utils._pytree.keystr(keypath)}"
|
|
843
|
+
if i is not None:
|
|
844
|
+
path += f".shape[{i}]"
|
|
845
|
+
msg = (
|
|
846
|
+
f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
|
|
847
|
+
f"conflicting with the inferred min/max range of "
|
|
848
|
+
f"[{trace_vr.lower}, {trace_vr.upper}], "
|
|
849
|
+
f"for {path}."
|
|
850
|
+
)
|
|
851
|
+
return msg
|
|
852
|
+
|
|
853
|
+
return None
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    """Patches ``torch._refs._maybe_broadcast``.

    Broadcasts every tensor in *args* to the common shape returned by
    ``patched__broadcast_shapes``. ``None`` and numbers are returned as they
    are. CPU scalar tensors are also returned as they are when
    *preserve_cpu_scalar_tensors* is true.

    PATCHED: the ``torch._check`` asserting that undecided dimensions are
    equal is removed. After the first pair of dimensions has been examined,
    ``should_expand`` always requests an expansion. As a result, any tensor
    whose rank equals the (non-zero) rank of the common shape goes through
    ``expand``.

    Returns:
        A tuple with one entry per argument.

    Raises:
        RuntimeError: if an argument is neither ``None``, a number nor a tensor.
    """
    from torch._prims_common import ShapeType, TensorLike, Number

    # Only tensor arguments take part in the common shape.
    target_shape = patched__broadcast_shapes(
        *(arg.shape if isinstance(arg, TensorLike) else None for arg in args)
    )

    def should_expand(a: ShapeType, b: ShapeType) -> bool:
        from torch.fx.experimental.symbolic_shapes import (
            guard_or_false,
            sym_and,
            sym_or,
        )

        if len(a) != len(b):
            return True

        for x, y in zip(a, b):
            if guard_or_false(x != y):
                # Known to be different.
                return True
            # Same or undecided: 1==1 does not broadcast, u0==1 / 1==u0 does.
            both_are_one = guard_or_false(sym_and(x == 1, y == 1))
            if not both_are_one and guard_or_false(sym_or(x == 1, y == 1)):
                return True
            # PATCHED: avoid errors. The upstream code assumed u0 == u1 with
            # torch._check(x == y, ...) and kept looping; this version asks for
            # an expansion (the alternative considered was guard_or_true(x != y)).
            return True

        # Only reached for two empty shapes.
        return False

    def _broadcast_one(value, shape):
        if value is None or isinstance(value, Number):
            return value
        if not isinstance(value, TensorLike):
            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(value))}!")
        if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(value):
            return value
        return value.expand(shape) if should_expand(value.shape, shape) else value

    return tuple(_broadcast_one(arg, target_shape) for arg in args)
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
def patched__broadcast_in_dim_meta(
|
|
922
|
+
a: torch._prims_common.TensorLikeType,
|
|
923
|
+
shape: torch._prims_common.ShapeType,
|
|
924
|
+
broadcast_dimensions: Sequence[int],
|
|
925
|
+
):
|
|
926
|
+
"""Patches ``torch._prims._broadcast_in_dim_meta``."""
|
|
927
|
+
from torch.fx.experimental.symbolic_shapes import (
|
|
928
|
+
guard_or_false,
|
|
929
|
+
guard_or_true,
|
|
930
|
+
sym_or,
|
|
931
|
+
)
|
|
932
|
+
|
|
933
|
+
# Type checks
|
|
934
|
+
assert isinstance(a, torch._prims_common.TensorLike)
|
|
935
|
+
assert isinstance(shape, Sequence)
|
|
936
|
+
assert isinstance(broadcast_dimensions, Sequence)
|
|
937
|
+
|
|
938
|
+
# every dimension must be accounted for
|
|
939
|
+
assert a.ndim == len(broadcast_dimensions)
|
|
940
|
+
|
|
941
|
+
# broadcast shape must have weakly more dimensions
|
|
942
|
+
assert len(shape) >= a.ndim
|
|
943
|
+
|
|
944
|
+
# broadcast_dimensions must be an ascending sequence
|
|
945
|
+
# (no relative reordering of dims) of integers and
|
|
946
|
+
# each dimension must be within the new shape
|
|
947
|
+
def _greater_than_reduce(acc, x):
|
|
948
|
+
assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
|
|
949
|
+
assert x > acc
|
|
950
|
+
assert x < len(shape)
|
|
951
|
+
|
|
952
|
+
return x
|
|
953
|
+
|
|
954
|
+
reduce(_greater_than_reduce, broadcast_dimensions, -1)
|
|
955
|
+
|
|
956
|
+
# shape must be broadcastable to
|
|
957
|
+
for idx, new_idx in enumerate(broadcast_dimensions):
|
|
958
|
+
torch._check(
|
|
959
|
+
sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
|
|
960
|
+
lambda idx=idx, new_idx=new_idx: (
|
|
961
|
+
f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
|
|
962
|
+
),
|
|
963
|
+
)
|
|
964
|
+
|
|
965
|
+
new_strides = []
|
|
966
|
+
original_idx = 0
|
|
967
|
+
for idx in range(len(shape)):
|
|
968
|
+
if idx in broadcast_dimensions:
|
|
969
|
+
# Assigns a stride of zero to dimensions
|
|
970
|
+
# which were actually broadcast
|
|
971
|
+
if guard_or_false(a.shape[original_idx] == 1):
|
|
972
|
+
if guard_or_false(a.shape[original_idx] == shape[idx]):
|
|
973
|
+
new_strides.append(a.stride()[original_idx])
|
|
974
|
+
else:
|
|
975
|
+
new_strides.append(0)
|
|
976
|
+
else:
|
|
977
|
+
# PATCHED: disabled this check
|
|
978
|
+
# torch._check(
|
|
979
|
+
# a.shape[original_idx] == shape[idx],
|
|
980
|
+
# lambda idx=idx, original_idx=original_idx: (
|
|
981
|
+
# f"non-broadcasting semantics require "
|
|
982
|
+
# f"{a.shape[original_idx]} == {shape[idx]}"
|
|
983
|
+
# ),
|
|
984
|
+
# )
|
|
985
|
+
new_strides.append(a.stride()[original_idx])
|
|
986
|
+
original_idx = original_idx + 1
|
|
987
|
+
else:
|
|
988
|
+
if guard_or_true(shape[idx] != 1):
|
|
989
|
+
# consistent with previous use of guard_size_oblivious
|
|
990
|
+
new_strides.append(0)
|
|
991
|
+
elif original_idx == a.ndim:
|
|
992
|
+
new_strides.append(1)
|
|
993
|
+
else:
|
|
994
|
+
new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
|
|
995
|
+
|
|
996
|
+
return a.as_strided(shape, new_strides, a.storage_offset())
|