onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +7 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +53 -17
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +2 -1
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +30 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +184 -151
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +20 -5
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +52 -20
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +540 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -5
  23. onnx_diagnostic/torch_models/validate.py +116 -50
  24. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  25. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/METADATA +11 -31
  26. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.14.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -1,7 +1,10 @@
+import functools
 import inspect
+import operator
 import os
 import traceback
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from functools import reduce
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -65,6 +68,8 @@ def patch__check_input_constraints_for_graph(
     verbose: int = 0,
 ) -> None:
     try:
+        # PATCHED: catches exception and prints out the information instead of
+        # stopping the conversion.
         return previous_function(input_placeholders, flat_args_with_path, range_constraints)
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -122,8 +127,7 @@ def patched_infer_size(a, b):
         if b1 or b2 or b3:
             expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
         else:
-            # In this case, the current implementation of torch fails (17/12/2024).
-            # Try model SmolLM.
+            # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
     return tuple(expandedSizes)

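
The patched else branch above keeps a symbolic upper bound via torch.sym_max instead of enforcing that the two sizes match. A minimal illustration of that per-dimension rule; the helper broadcast_dim below is hypothetical and not part of onnx_diagnostic or torch:

import torch

def broadcast_dim(size_a, size_b):
    # Statically known broadcasting: one side is 1, keep the other.
    if size_a == 1:
        return size_b
    if size_b == 1:
        return size_a
    # Generic case: keep an upper bound. On plain ints torch.sym_max is just
    # max(); on SymInts it stays symbolic instead of asserting size_a == size_b.
    return torch.sym_max(size_a, size_b)

print(broadcast_dim(4, 1), broadcast_dim(3, 3))  # 4 3
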
@@ -132,7 +136,11 @@ def patched__broadcast_shapes(*_shapes):
     """Patches ``torch._refs._broadcast_shapes``."""
     from functools import reduce
     from torch._prims_common import IntLike
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        guard_or_false,
+        is_nested_int,
+    )

     shapes = tuple(
         (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)
@@ -142,17 +150,30 @@ def patched__broadcast_shapes(*_shapes):
     if len(shapes) == 0:
         return None

-    # Type checking
-    # TODO: make common validations available as utils
     for shape in shapes:
-        assert isinstance(shape, Sequence)
+        if not isinstance(shape, Sequence):
+            raise RuntimeError(
+                "Input shapes should be of type ints, a tuple of ints, "
+                "or a list of ints, got ",
+                shape,
+            )

     # Computes common shape
-    common_shape = [  # List[Union[int, torch.SymInt]]
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
+    common_shape = [1] * reduce(max, (len(shape) for shape in shapes))
     for _arg_idx, shape in enumerate(shapes):
         for idx in range(-1, -1 - len(shape), -1):
+            if is_nested_int(shape[idx]):
+                # Broadcasting is allowed for (j0, 1) or (j0, j0);
+                # not (j0, j1), (j0, 5), etc.
+                if is_nested_int(common_shape[idx]) and guard_or_false(
+                    shape[idx] == common_shape[idx]
+                ):
+                    continue
+            else:
+                if guard_or_false(shape[idx] == common_shape[idx]):
+                    continue
+            # PATCHED: two cases, if == for sure, no broadcast,
+            # otherwise maybe broadcast with max(dimensions)
             if guard_size_oblivious(common_shape[idx] == 1):
                 if shape[idx] < 0:
                     raise ValueError(
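
The rewritten loop relies on guard_or_false from torch.fx.experimental.symbolic_shapes, which the patch itself imports: it answers a symbolic comparison with a best-effort boolean instead of raising a data-dependent error when the comparison cannot be decided. A small sketch of its behaviour on ordinary, backed values; the unbacked behaviour described in the comments only shows up while tracing or exporting:

import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

# On backed (known) values the helpers simply evaluate the condition.
print(guard_or_false(torch.ones(2, 3).shape[0] == 2))  # True
print(guard_or_true(torch.ones(2, 3).shape[0] == 5))   # False

# During export, when the comparison involves an unbacked symbol and cannot be
# decided, guard_or_false returns False and guard_or_true returns True instead
# of raising a data-dependent error; that is what the patched loop exploits.
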
@@ -172,6 +193,7 @@ class patched_ShapeEnv:
     ) -> None:
         if self.frozen:
             self.counter["ignored_backward_guard"] += 1
+            # PATCHED: raised an exception instead of logging.
             raise AssertionError(
                 f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
                 f"this could result in accuracy problems"
@@ -338,11 +360,13 @@ class patched_ShapeEnv:
                 },
             )

+            # PATCHED: removed lines
             # if config.print_specializations:
             #     self.log.warning(
             #         "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
             #     )
             #     self.log.debug("SPECIALIZATION", stack_info=True)
+            # PATCHED: replaces logging by raising an exception
             assert msg != "range_refined_to_singleton", (
                 f"patched_ShapeEnv: A dynamic dimension becomes static! "
                 f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
@@ -364,6 +388,7 @@ class patched_ShapeEnv:
         self, prefix: str, g: "SympyBoolean", forcing_spec: bool  # noqa: F821
     ) -> None:
         self._log_guard_remember(prefix=prefix, g=g, forcing_spec=forcing_spec)
+        # PATCHED: removed
         # It happens too often to be relevant.
         # sloc, _maybe_extra_debug = self._get_stack_summary(True)
         # warnings.warn(
@@ -374,6 +399,284 @@ class patched_ShapeEnv:
         #     stacklevel=0,
         # )

+    def _evaluate_expr(
+        self,
+        orig_expr: "sympy.Basic",  # noqa: F821
+        hint: Optional[Union[bool, int, float]] = None,
+        fx_node: Optional[torch.fx.Node] = None,
+        size_oblivious: bool = False,
+        fallback_value: Optional[bool] = None,
+        *,
+        forcing_spec: bool = False,
+    ) -> "sympy.Basic":  # noqa: F821
+        # TODO: split conjunctions and evaluate them separately
+        import sympy
+        from torch.fx.experimental import _config as config
+        from torch.fx.experimental.symbolic_shapes import (
+            SympyBoolean,
+            log,
+            SymT,
+            symbol_is_type,
+        )
+        from torch._guards import ShapeGuard
+
+        if isinstance(
+            orig_expr,
+            (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
+        ):
+            return orig_expr
+
+        # Don't track this one. (Because this cache is inside this function the
+        # cache only lasts for the invocation of this function call)
+        @functools.cache
+        def compute_concrete_val() -> sympy.Basic:
+            if hint is None:
+                # This is only ever called for expressions WITHOUT unbacked
+                # symbols
+                r = self.size_hint(orig_expr)
+                assert r is not None
+                return r
+            else:
+                return sympy.sympify(hint)
+
+        concrete_val: Optional[sympy.Basic]
+
+        # Check if:
+        #   1. 'translation_validation' is set
+        #   2. the corresponding 'fx_node' is not 'None'
+        #   3. the guard should not be suppressed
+        #   4. the guard doesn't contain backed symfloat symbols
+        #      since z3 can't handle floats
+        #   5. fallback_value is none.
+        # If all of the above check, we create an FX node representing the
+        # actual expression to be guarded.
+        node = None
+        fresh = False
+        if (
+            self._translation_validation_enabled
+            and fx_node is not None
+            and not self._suppress_guards_tls()
+            and not size_oblivious
+            and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
+            and fallback_value is None
+        ):
+            # TODO: does this even worked with unbacked :think:
+            concrete_val = compute_concrete_val()
+            if concrete_val is sympy.true:
+                node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+            elif concrete_val is sympy.false:
+                neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
+                node, fresh = self._create_fx_call_function(torch._assert, (neg,))
+            else:
+                eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
+                node, fresh = self._create_fx_call_function(torch._assert, (eql,))
+
+            assert node is not None
+            # If this is a fresh node, we have to remember the event index that
+            # corresponds to this assertion node.
+            # Reason: so that, given an assertion node, we can replay the ShapeEnv
+            # events until the point where this assertion node was freshly created.
+            if fresh:
+                self._add_fx_node_metadata(node)
+
+        # After creating the FX node corresponding to orig_expr, we must make sure that
+        # no error will be raised until the end of this function.
+        #
+        # Reason: the translation validation may become invalid otherwise.
+        #
+        # If an error is raised before the end of this function, we remove the FX node
+        # inserted, and re-raise the error.
+        guard = None
+
+        try:
+            if orig_expr.is_number:
+                self.log.debug("eval %s [trivial]", orig_expr)
+                if hint is not None:
+                    if isinstance(hint, bool):
+                        assert orig_expr == hint, f"{orig_expr} != {hint}"
+                    else:
+                        assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
+                return orig_expr
+
+            expr = orig_expr
+
+            static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
+            if static_expr is not None:
+                self.log.debug(
+                    "eval %s == %s [statically known]",
+                    (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious),
+                    static_expr,
+                )
+                if not size_oblivious and config.backed_size_oblivious and hint is not None:
+                    # TODO: maybe reconcile this with use of counterfactual hints
+                    # in unbacked case
+                    assert static_expr == hint, f"{static_expr} != {hint}"
+                return static_expr
+
+            transmute_into_runtime_assert = False
+
+            concrete_val = None
+            if not (expr.free_symbols <= self.var_to_val.keys()):
+                # TODO: dedupe this with _maybe_evaluate_static
+                # Attempt to eliminate the unbacked SymInt
+                new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+                assert new_expr is not None
+                if not (new_expr.free_symbols <= self.var_to_val.keys()):
+                    ok = False
+
+                    # fallback_value is set when guard_or_true or guard_or_false are used.
+                    if not ok and fallback_value is not None:
+                        self._log_suppressed_dde(orig_expr, fallback_value)
+                        return fallback_value
+
+                    # oblivious_var_to_val will be defined iff we have sizes
+                    # with DimDynamic.OBLIVIOUS_SIZE type.
+                    # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+                    if (
+                        self.oblivious_var_to_val
+                        and not (
+                            correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
+                        ).free_symbols
+                        and not (
+                            counterfactual_hint := orig_expr.xreplace(
+                                {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                            )
+                        ).free_symbols
+                        and correct_hint == counterfactual_hint
+                    ):
+                        # TODO: better logging
+                        log.info(
+                            "oblivious_size %s -> %s (passed counterfactual)",
+                            orig_expr,
+                            # pyrefly: ignore # unbound-name
+                            correct_hint,
+                        )
+                        # pyrefly: ignore # unbound-name
+                        concrete_val = correct_hint
+                        # NB: do NOT transmute into runtime assert
+                        ok = True
+
+                    # unbacked_var_to_val is not None iff propagate_real_tensors is on.
+                    # if propagate_real_tensors is on, we check the example values
+                    # to generate (unsound_result)
+                    # and if they pass we add a runtime assertions and continue.
+                    if (
+                        not ok
+                        and self.unbacked_var_to_val
+                        and not (
+                            unsound_result := orig_expr.xreplace(
+                                self.unbacked_var_to_val
+                            ).xreplace(self.var_to_val)
+                        ).free_symbols
+                    ):
+                        # pyrefly: ignore # unbound-name
+                        self._log_real_tensor_propagation(orig_expr, unsound_result)
+                        transmute_into_runtime_assert = True
+                        # pyrefly: ignore # unbound-name
+                        concrete_val = unsound_result
+                        ok = True
+
+                    # Check if this is coming from a python assert statement,
+                    # if so, convert it to a runtime assertion
+                    # instead of failing.
+                    if not ok and self.trace_asserts and self._is_python_assert():
+                        concrete_val = sympy.true
+                        transmute_into_runtime_assert = True
+                        ok = True
+
+                    # PATCHED: ok -> True
+                    ok = True
+                    # if not ok:
+                    #     raise self._make_data_dependent_error(
+                    #         expr.xreplace(self.var_to_val),
+                    #         expr,
+                    #         expr_sym_node_id=self._expr_sym_node_id,
+                    #     )
+                else:
+                    expr = new_expr
+
+            if concrete_val is None:
+                concrete_val = compute_concrete_val()
+            self._check_frozen(expr, concrete_val)
+
+            if (
+                config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+                and isinstance(hint, bool)
+                and isinstance(expr, (sympy.Eq, sympy.Ne))
+            ):
+                expr = sympy.Not(expr)
+
+            # Turn this into a boolean expression, no longer need to consult
+            # concrete_val
+            if concrete_val is sympy.true:
+                g = cast(SympyBoolean, expr)
+            elif concrete_val is sympy.false:
+                g = sympy.Not(expr)
+            else:
+                g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
+
+            if transmute_into_runtime_assert:
+                self.guard_or_defer_runtime_assert(
+                    g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
+                )
+                return concrete_val
+
+            if not self._suppress_guards_tls():
+                self._log_guard("eval", g, forcing_spec=forcing_spec)
+
+                # TODO: If we successfully eliminate a symbol via equality, it
+                # is not actually necessary to save a guard for the equality,
+                # as we will implicitly generate a guard when we match that
+                # input against the symbol. Probably the easiest way to
+                # implement this is to have maybe_guard_rel return a bool
+                # saying if it "subsumed" the guard (and therefore the guard
+                # is no longer necessary)
+                self._maybe_guard_rel(g)
+
+                if (
+                    torch.compiler.is_exporting()
+                    and self.prefer_deferred_runtime_asserts_over_guards
+                ):
+                    # it's fine to defer simple guards here without checking,
+                    # the _maybe_guard_rel() call above will set replacements if possible,
+                    # and so the result here will be statically known
+                    self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+                else:
+                    # at this point, we've evaluated the concrete expr value, and have
+                    # flipped/negated the guard if necessary. Now we know what to guard
+                    # or defer to runtime assert on.
+                    guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
+                    self.guards.append(guard)
+                    self.axioms.update(dict(self.get_implications(self.simplify(g))))
+            else:
+                self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
+
+        except Exception:
+            if fresh:
+                self._remove_fx_node(node)
+            raise
+
+        if not self._suppress_guards_tls():
+            if guard is not None:  # we might have deferred this to runtime assert
+                for s in g.free_symbols:
+                    self.symbol_guard_counter[s] += 1
+                    # Forcing_spec to avoid infinite recursion
+                    if (
+                        not forcing_spec
+                        and config.symbol_guard_limit_before_specialize is not None
+                        and self.symbol_guard_counter[s]
+                        > config.symbol_guard_limit_before_specialize
+                    ):
+                        # Force specialization
+                        self.log.info(
+                            "symbol_guard_limit_before_specialize=%s exceeded on %s",
+                            config.symbol_guard_limit_before_specialize,
+                            s,
+                        )
+                        self.evaluate_expr(s, forcing_spec=True)
+
+        return concrete_val
+

 def patched_vmap(func, in_dims=0, out_dims=0):
     """
@@ -464,3 +767,230 @@ def patched_vmap(func, in_dims=0, out_dims=0):
         return results

     return wrapped
+
+
+def patched__constrain_user_specified_dimhint_range(
+    symint: torch.SymInt,
+    hint: int,
+    dim: "_DimHint",  # noqa: F821
+    range_constraints,
+    shape_env,
+    keypath: "KeyPath",  # noqa: F821
+    i: Optional[int] = None,
+) -> Optional[str]:
+    """Patches ``torch._export.non_strict_utils._constrain_user_specified_dimhint_range``."""
+    from torch._export.non_strict_utils import is_int, int_oo, _DimHintType, ValueRanges
+
+    trace_vr = (
+        range_constraints[symint.node.expr]
+        if not is_int(symint)
+        else ValueRanges(int(symint), int(symint))
+    )
+    # warn on 0/1 specialization for Dim.AUTO; not an actual error
+    # PATCHED: remove logging
+    # if dim.type == _DimHintType.AUTO and trace_vr.is_singleton() and hint in (0, 1):
+    #     pathstr = f"inputs{pytree.keystr(keypath)}"
+    #     if i is not None:
+    #         pathstr += f".shape[{i}]"
+    #     msg = (
+    #         f"dimension {pathstr} 0/1 specialized; Dim.AUTO was specified along "
+    #         f"with a sample input with hint = {hint}."
+    #     )
+    #     log.warning(msg)
+
+    try:
+        user_vr = ValueRanges(
+            lower=0 if dim.min is None else dim.min,
+            upper=int_oo if dim.max is None else dim.max,
+        )
+        if is_int(symint):
+            out_vr = trace_vr & user_vr
+        else:
+            range_constraints[symint.node.expr] &= user_vr
+            shape_env.var_to_range[symint.node._expr] &= user_vr
+            out_vr = range_constraints[symint.node.expr]
+
+        # check for Dim.DYNAMIC specializations; special case error message on 0/1
+        if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton():
+            path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+            if i is not None:
+                path += f".shape[{i}]"
+            if (
+                trace_vr.is_singleton()
+                and hint in (0, 1)
+                # PATCHED: line removed
+                # and not torch.fx.experimental._config.backed_size_oblivious
+            ):
+                return None
+                # PATCHED: line removed
+                # msg = (
+                #     f"- Received user-specified dim hint "
+                #     f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                #     f"but export 0/1 specialized due to hint of "
+                #     f"{hint} for dimension {path}."
+                # )
+            else:
+                msg = (
+                    f"- Received user-specified dim hint "
+                    f"Dim.DYNAMIC(min={dim.min}, max={dim.max}), "
+                    f"but tracing inferred a static shape of "
+                    f"{out_vr.lower} for dimension {path}."
+                )
+            return msg
+
+    except torch.utils._sympy.value_ranges.ValueRangeError:
+        path = f"inputs{torch.utils._pytree.keystr(keypath)}"
+        if i is not None:
+            path += f".shape[{i}]"
+        msg = (
+            f"- Received user-specified min/max range of [{dim.min}, {dim.max}], "
+            f"conflicting with the inferred min/max range of "
+            f"[{trace_vr.lower}, {trace_vr.upper}], "
+            f"for {path}."
+        )
+        return msg
+
+    return None
+
+
+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+    """Patches ``torch._refs._maybe_broadcast``."""
+    from torch._prims_common import ShapeType, TensorLike, Number
+
+    # Computes common shape
+    common_shape = patched__broadcast_shapes(
+        *(t.shape if isinstance(t, TensorLike) else None for t in args)
+    )
+
+    def should_expand(a: ShapeType, b: ShapeType) -> bool:
+        from torch.fx.experimental.symbolic_shapes import (
+            guard_or_false,
+            sym_and,
+            sym_or,
+        )
+
+        if len(a) != len(b):
+            return True
+
+        for x, y in zip(a, b):
+            if guard_or_false(x != y):
+                # We know they are not the same.
+                return True
+
+            # They are the same or we do not know if they are the same or not.
+            # 1==1 no-broadcast
+            # u0==1 and 1==u0 cases. We broadcast!
+            if guard_or_false(sym_and(x == 1, y == 1)):
+                pass
+            elif guard_or_false(sym_or(x == 1, y == 1)):
+                # assume broadcasting.
+                return True
+
+            # u0==u1 assume the same, no broadcasting!
+            # PATCHED: avoid errors
+            return True  # guard_or_true(x != y)
+            # torch._check(
+            #     x == y,
+            #     lambda x=x, y=y: (
+            #         f"sizes assumed to be the same due to unbacked "
+            #         f"broadcasting semantics x={x!r}, y={y!r}"
+            #     ),
+            # )
+
+        return False
+
+    def __maybe_broadcast(x, shape):
+        if x is None:
+            return None
+        elif isinstance(x, Number):
+            return x
+        elif isinstance(x, TensorLike):
+            if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
+                return x
+
+            if should_expand(x.shape, common_shape):
+                return x.expand(common_shape)
+
+            return x
+        else:
+            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
+
+    return tuple(__maybe_broadcast(x, common_shape) for x in args)
+
+
+def patched__broadcast_in_dim_meta(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            else:
+                # PATCHED: disabled this check
+                # torch._check(
+                #     a.shape[original_idx] == shape[idx],
+                #     lambda idx=idx, original_idx=original_idx: (
+                #         f"non-broadcasting semantics require "
+                #         f"{a.shape[original_idx]} == {shape[idx]}"
+                #     ),
+                # )
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
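
patched__constrain_user_specified_dimhint_range removes the error message produced when a Dim.DYNAMIC hint meets a 0/1 sample size, returning None instead. A hedged sketch, not from the package, of the export call that exercises this path with stock torch.export; whether the 1-sized sample is rejected depends on the PyTorch version and configuration:

import torch
from torch.export import Dim

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

try:
    ep = torch.export.export(
        M(), (torch.randn(1, 8),), dynamic_shapes={"x": {0: Dim.DYNAMIC}}
    )
except Exception as e:
    print("rejected:", type(e).__name__)  # the 0/1-specialization complaint
else:
    print("exported with signature:", ep.graph_signature)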