onnx-diagnostic 0.7.13__py3-none-any.whl → 0.7.15__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +15 -3
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/config_helper.py +2 -1
- onnx_diagnostic/helpers/log_helper.py +53 -17
- onnx_diagnostic/helpers/rt_helper.py +3 -3
- onnx_diagnostic/tasks/image_text_to_text.py +6 -5
- onnx_diagnostic/tasks/text_generation.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +7 -1
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +1 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +445 -9
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +79 -28
- onnx_diagnostic/torch_models/hghub/model_inputs.py +31 -5
- onnx_diagnostic/torch_models/validate.py +41 -28
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/RECORD +21 -21
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.13.dist-info → onnx_diagnostic-0.7.15.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -1,7 +1,10 @@
+import functools
 import inspect
+import operator
 import os
 import traceback
-from
+from functools import reduce
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
 
@@ -85,7 +88,7 @@ def patch__check_input_constraints_for_graph(
 
 def patched_infer_size(a, b):
     """Patches ``torch._subclasses.fake_impls.infer_size``."""
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     dimsA = len(a)
     dimsB = len(b)
@@ -110,19 +113,19 @@ def patched_infer_size(a, b):
         # were not the case, we'd need to write this using torch.sym_or() or
         # something like that).
         try:
-            b1 = guard_size_oblivious(sizeA == 1)
+            b1 = guard_or_false(sizeA == 1)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b1 = False
         try:
-            b2 = guard_size_oblivious(sizeB == 1)
+            b2 = guard_or_false(sizeB == 1)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b2 = False
         try:
-            b3 = guard_size_oblivious(sizeA == sizeB)
+            b3 = guard_or_false(sizeA == sizeB)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b3 = False
         if b1 or b2 or b3:
-            expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
+            expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
         else:
             # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
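The hunks above replace checks built on guard_size_oblivious, which can raise GuardOnDataDependentSymNode on unbacked sizes, with guard_or_false, which simply answers False when a comparison is undecidable. A minimal sketch of the difference, not part of the diff, assuming a recent PyTorch that exposes guard_or_false:

from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # a data-dependent size, e.g. from nonzero()

# "u0 == 1" is undecidable: guard_or_false falls back to False instead of raising
# GuardOnDataDependentSymNode, so the try/except above becomes a no-op safety net.
print(guard_or_false(u0 == 1))   # False
print(guard_or_false(u0 == u0))  # True: provable without adding a guard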
@@ -134,7 +137,6 @@ def patched__broadcast_shapes(*_shapes):
     from functools import reduce
     from torch._prims_common import IntLike
     from torch.fx.experimental.symbolic_shapes import (
-        guard_size_oblivious,
         guard_or_false,
         is_nested_int,
     )
@@ -171,13 +173,15 @@ def patched__broadcast_shapes(*_shapes):
                 continue
             # PATCHED: two cases, if == for sure, no broadcast,
             # otherwise maybe broadcast with max(dimensions)
-            if
+            if guard_or_false(common_shape[idx] != 1):
+                pass
+            elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
                 if shape[idx] < 0:
                     raise ValueError(
                         "Attempting to broadcast a dimension with negative length!"
                     )
                 common_shape[idx] = shape[idx]
-
+            else:
                 common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
 
     return common_shape
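Read as a decision table, the patched loop handles each dimension pair in three steps: keep the accumulated dimension if it is provably not 1, take the incoming dimension if the accumulated one is provably 1 (or the incoming one provably is not), and only pay for torch.sym_max in the fully undecidable case. A plain-int analogue, for illustration only; with concrete ints the third branch never fires, it exists for unbacked dims:

def merge_dim(common: int, new: int) -> int:
    # Branch 1: common is provably not 1 -> keep it
    # (shapes are assumed broadcast-compatible, as in the patched code).
    if common != 1:
        return common
    # Branch 2: common is 1 (or new is provably not 1) -> the incoming dim wins.
    if common == 1 or new != 1:
        if new < 0:
            raise ValueError("Attempting to broadcast a dimension with negative length!")
        return new
    # Branch 3: only reachable for unbacked sizes, where sym_max defers the choice.
    return max(common, new)

assert merge_dim(1, 7) == 7
assert merge_dim(3, 1) == 3
assert merge_dim(5, 5) == 5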
@@ -357,6 +361,10 @@ class patched_ShapeEnv:
                 },
             )
 
+            for source in self.var_to_sources.get(a, []):
+                if user_tb:
+                    self.specialization_stacks[source] = user_tb
+
             # PATCHED: removed lines
             # if config.print_specializations:
             #     self.log.warning(
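The inserted statements mirror what recent upstream ShapeEnv._set_replacement records: for every Source that produced the specialized symbol a, the user traceback user_tb (computed earlier in the surrounding method, not shown in this hunk) is stored so the exporter can report where a dynamic dimension became static. A hedged sketch of inspecting those records after an export, assuming the recent-PyTorch attribute names used above:

# shape_env: a torch.fx.experimental.symbolic_shapes.ShapeEnv after an export run
for source, tb in shape_env.specialization_stacks.items():
    print(source.name())         # which input dimension was specialized
    print("".join(tb.format()))  # the user stack that forced it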
@@ -396,6 +404,284 @@ class patched_ShapeEnv:
             #         stacklevel=0,
             #     )
 
+    def _evaluate_expr(
+        self,
+        orig_expr: "sympy.Basic",  # noqa: F821
+        hint: Optional[Union[bool, int, float]] = None,
+        fx_node: Optional[torch.fx.Node] = None,
+        size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
+        *,
+        forcing_spec: bool = False,
+    ) -> "sympy.Basic":  # noqa: F821
+        # TODO: split conjunctions and evaluate them separately
+        import sympy
+        from torch.fx.experimental import _config as config
+        from torch.fx.experimental.symbolic_shapes import (
+            SympyBoolean,
+            log,
+            SymT,
+            symbol_is_type,
+        )
+        from torch._guards import ShapeGuard
+
+        if isinstance(
+            orig_expr,
+            (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
+        ):
+            return orig_expr
+
+        # Don't track this one. (Because this cache is inside this function the
+        # cache only lasts for the invocation of this function call)
+        @functools.cache
+        def compute_concrete_val() -> sympy.Basic:
+            if hint is None:
+                # This is only ever called for expressions WITHOUT unbacked
+                # symbols
+                r = self.size_hint(orig_expr)
+                assert r is not None
+                return r
+            else:
+                return sympy.sympify(hint)
+
+        concrete_val: Optional[sympy.Basic]
+
+        # Check if:
+        # 1. 'translation_validation' is set
+        # 2. the corresponding 'fx_node' is not 'None'
+        # 3. the guard should not be suppressed
+        # 4. the guard doesn't contain backed symfloat symbols
+        #    since z3 can't handle floats
+        # 5. fallback_value is none.
+        # If all of the above check, we create an FX node representing the
+        # actual expression to be guarded.
+        node = None
+        fresh = False
+        if (
+            self._translation_validation_enabled
+            and fx_node is not None
+            and not self._suppress_guards_tls()
+            and not size_oblivious
+            and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
+            and fallback_value is None
+        ):
+            # TODO: does this even worked with unbacked :think:
+            concrete_val = compute_concrete_val()
+            if concrete_val is sympy.true:
+                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            elif concrete_val is sympy.false:
+                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
+                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
+            else:
+                eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
+                node, fresh = self._create_fx_call_function(torch._assert, (eql,))
+
+            assert node is not None
+            # If this is a fresh node, we have to remember the event index that
+            # corresponds to this assertion node.
+            # Reason: so that, given an assertion node, we can replay the ShapeEnv
+            # events until the point where this assertion node was freshly created.
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        # After creating the FX node corresponding to orig_expr, we must make sure that
+        # no error will be raised until the end of this function.
+        #
+        # Reason: the translation validation may become invalid otherwise.
+        #
+        # If an error is raised before the end of this function, we remove the FX node
+        # inserted, and re-raise the error.
+        guard = None
+
+        try:
+            if orig_expr.is_number:
+                self.log.debug("eval %s [trivial]", orig_expr)
+                if hint is not None:
+                    if isinstance(hint, bool):
+                        assert orig_expr == hint, f"{orig_expr} != {hint}"
+                    else:
+                        assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
+                return orig_expr
+
+            expr = orig_expr
+
+            static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
+            if static_expr is not None:
+                self.log.debug(
+                    "eval %s == %s [statically known]",
+                    (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious),
+                    static_expr,
+                )
+                if not size_oblivious and config.backed_size_oblivious and hint is not None:
+                    # TODO: maybe reconcile this with use of counterfactual hints
+                    # in unbacked case
+                    assert static_expr == hint, f"{static_expr} != {hint}"
+                return static_expr
+
+            transmute_into_runtime_assert = False
+
+            concrete_val = None
+            if not (expr.free_symbols <= self.var_to_val.keys()):
+                # TODO: dedupe this with _maybe_evaluate_static
+                # Attempt to eliminate the unbacked SymInt
+                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+                assert new_expr is not None
+                if not (new_expr.free_symbols <= self.var_to_val.keys()):
+                    ok = False
+
+                    # fallback_value is set when guard_or_true or guard_or_false are used.
+                    if not ok and fallback_value is not None:
+                        self._log_suppressed_dde(orig_expr, fallback_value)
+                        return fallback_value
+
+                    # oblivious_var_to_val will be defined iff we have sizes
+                    # with DimDynamic.OBLIVIOUS_SIZE type.
+                    # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+                    if (
+                        self.oblivious_var_to_val
+                        and not (
+                            correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
+                        ).free_symbols
+                        and not (
+                            counterfactual_hint := orig_expr.xreplace(
+                                {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                            )
+                        ).free_symbols
+                        and correct_hint == counterfactual_hint
+                    ):
+                        # TODO: better logging
+                        log.info(
+                            "oblivious_size %s -> %s (passed counterfactual)",
+                            orig_expr,
+                            # pyrefly: ignore # unbound-name
+                            correct_hint,
+                        )
+                        # pyrefly: ignore # unbound-name
+                        concrete_val = correct_hint
+                        # NB: do NOT transmute into runtime assert
+                        ok = True
+
+                    # unbacked_var_to_val is not None iff propagate_real_tensors is on.
+                    # if propagate_real_tensors is on, we check the example values
+                    # to generate (unsound_result)
+                    # and if they pass we add a runtime assertions and continue.
+                    if (
+                        not ok
+                        and self.unbacked_var_to_val
+                        and not (
+                            unsound_result := orig_expr.xreplace(
+                                self.unbacked_var_to_val
+                            ).xreplace(self.var_to_val)
+                        ).free_symbols
+                    ):
+                        # pyrefly: ignore # unbound-name
+                        self._log_real_tensor_propagation(orig_expr, unsound_result)
+                        transmute_into_runtime_assert = True
+                        # pyrefly: ignore # unbound-name
+                        concrete_val = unsound_result
+                        ok = True
+
+                    # Check if this is coming from a python assert statement,
+                    # if so, convert it to a runtime assertion
+                    # instead of failing.
+                    if not ok and self.trace_asserts and self._is_python_assert():
+                        concrete_val = sympy.true
+                        transmute_into_runtime_assert = True
+                        ok = True
+
+                    # PATCHED: ok -> True
+                    ok = True
+                    # if not ok:
+                    #     raise self._make_data_dependent_error(
+                    #         expr.xreplace(self.var_to_val),
+                    #         expr,
+                    #         expr_sym_node_id=self._expr_sym_node_id,
+                    #     )
+                else:
+                    expr = new_expr
+
+            if concrete_val is None:
+                concrete_val = compute_concrete_val()
+            self._check_frozen(expr, concrete_val)
+
+            if (
+                config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+                and isinstance(hint, bool)
+                and isinstance(expr, (sympy.Eq, sympy.Ne))
+            ):
+                expr = sympy.Not(expr)
+
+            # Turn this into a boolean expression, no longer need to consult
+            # concrete_val
+            if concrete_val is sympy.true:
+                g = cast(SympyBoolean, expr)
+            elif concrete_val is sympy.false:
+                g = sympy.Not(expr)
+            else:
+                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
+
+            if transmute_into_runtime_assert:
+                self.guard_or_defer_runtime_assert(
+                    g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
+                )
+                return concrete_val
+
+            if not self._suppress_guards_tls():
+                self._log_guard("eval", g, forcing_spec=forcing_spec)
+
+                # TODO: If we successfully eliminate a symbol via equality, it
+                # is not actually necessary to save a guard for the equality,
+                # as we will implicitly generate a guard when we match that
+                # input against the symbol. Probably the easiest way to
+                # implement this is to have maybe_guard_rel return a bool
+                # saying if it "subsumed" the guard (and therefore the guard
+                # is no longer necessary)
+                self._maybe_guard_rel(g)
+
+                if (
+                    torch.compiler.is_exporting()
+                    and self.prefer_deferred_runtime_asserts_over_guards
+                ):
+                    # it's fine to defer simple guards here without checking,
+                    # the _maybe_guard_rel() call above will set replacements if possible,
+                    # and so the result here will be statically known
+                    self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+                else:
+                    # at this point, we've evaluated the concrete expr value, and have
+                    # flipped/negated the guard if necessary. Now we know what to guard
+                    # or defer to runtime assert on.
+                    guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
+                    self.guards.append(guard)
+                    self.axioms.update(dict(self.get_implications(self.simplify(g))))
+            else:
+                self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
+
+        except Exception:
+            if fresh:
+                self._remove_fx_node(node)
+            raise
+
+        if not self._suppress_guards_tls():
+            if guard is not None:  # we might have deferred this to runtime assert
+                for s in g.free_symbols:
+                    self.symbol_guard_counter[s] += 1
+                    # Forcing_spec to avoid infinite recursion
+                    if (
+                        not forcing_spec
+                        and config.symbol_guard_limit_before_specialize is not None
+                        and self.symbol_guard_counter[s]
+                        > config.symbol_guard_limit_before_specialize
+                    ):
+                        # Force specialization
+                        self.log.info(
+                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
+                            config.symbol_guard_limit_before_specialize,
+                            s,
+                        )
+                        self.evaluate_expr(s, forcing_spec=True)
+
+        return concrete_val
+
 
 def patched_vmap(func, in_dims=0, out_dims=0):
     """
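The "# PATCHED: ok -> True" line is the crux of this method: where the upstream _evaluate_expr would raise a data-dependent error via _make_data_dependent_error, the patched copy presses on with a concrete value or a deferred runtime assert. The patch is installed, together with the rest of patch_torch.py, by the package's context manager; a sketch of the intended entry point, with the API as documented for onnx-diagnostic:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x):
        # nonzero() produces an unbacked, data-dependent first dimension
        return torch.nonzero(x).to(torch.float32) * 2

inputs = ((torch.rand(4, 4) > 0.5).to(torch.float32),)
with torch_export_patches():  # swaps in patched_ShapeEnv._evaluate_expr and friends
    ep = torch.export.export(Model(), inputs)
print(ep)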
@@ -570,3 +856,153 @@ def patched__constrain_user_specified_dimhint_range(
         return msg
 
     return None
+
+
+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+    """Patches ``torch._refs._maybe_broadcast``."""
+    from torch._prims_common import ShapeType, TensorLike, Number
+
+    # Computes common shape
+    common_shape = patched__broadcast_shapes(
+        *(t.shape if isinstance(t, TensorLike) else None for t in args)
+    )
+
+    def should_expand(a: ShapeType, b: ShapeType) -> bool:
+        from torch.fx.experimental.symbolic_shapes import (
+            guard_or_false,
+            sym_and,
+            sym_or,
+        )
+
+        if len(a) != len(b):
+            return True
+
+        for x, y in zip(a, b):
+            if guard_or_false(x != y):
+                # We know they are not the same.
+                return True
+
+            # They are the same or we do not know if they are the same or not.
+            # 1==1 no-broadcast
+            # u0==1 and 1==u0 cases. We broadcast!
+            if guard_or_false(sym_and(x == 1, y == 1)):
+                pass
+            elif guard_or_false(sym_or(x == 1, y == 1)):
+                # assume broadcasting.
+                return True
+
+            # u0==u1 assume the same, no broadcasting!
+            # PATCHED: avoid errors
+            return True  # guard_or_true(x != y)
+            # torch._check(
+            #     x == y,
+            #     lambda x=x, y=y: (
+            #         f"sizes assumed to be the same due to unbacked "
+            #         f"broadcasting semantics x={x!r}, y={y!r}"
+            #     ),
+            # )
+
+        return False
+
+    def __maybe_broadcast(x, shape):
+        if x is None:
+            return None
+        elif isinstance(x, Number):
+            return x
+        elif isinstance(x, TensorLike):
+            if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
+                return x
+
+            if should_expand(x.shape, common_shape):
+                return x.expand(common_shape)
+
+            return x
+        else:
+            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
+
+    return tuple(__maybe_broadcast(x, common_shape) for x in args)
+
+
+def patched__broadcast_in_dim_meta(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
+            else:
+                torch._check(
+                    a.shape[original_idx] == shape[idx],
+                    lambda idx=idx, original_idx=original_idx: (
+                        f"non-broadcasting semantics require "
+                        f"{a.shape[original_idx]} == {shape[idx]}, "
+                        f"{guard_or_false(a.shape[idx] != 1)}, "
+                        f"guard_or_false(a.shape[idx] == 1)="
+                        f"{guard_or_false(a.shape[idx] == 1)}, "
+                        f"a.stride()={a.stride()}, idx={idx}, "
+                        f"original_idx={original_idx}"
+                    ),
+                )
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
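Both new helpers rest on one invariant that is easy to check with concrete tensors: a dimension that is actually broadcast is materialized with stride 0, while a dimension that is kept keeps its original stride. For example:

import torch

a = torch.arange(3.0).reshape(3, 1)  # shape (3, 1), strides (1, 1)
b = a.expand(3, 4)                   # broadcast the size-1 dim up to 4
assert b.stride() == (1, 0)          # the broadcast dim gets stride 0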