onnx-diagnostic 0.7.13__py3-none-any.whl → 0.7.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,7 +1,10 @@
+ import functools
  import inspect
+ import operator
  import os
  import traceback
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from functools import reduce
+ from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
  import torch
  from torch._subclasses.fake_tensor import FakeTensorMode

@@ -85,7 +88,7 @@ def patch__check_input_constraints_for_graph(

  def patched_infer_size(a, b):
      """Patches ``torch._subclasses.fake_impls.infer_size``."""
-     from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+     from torch.fx.experimental.symbolic_shapes import guard_or_false

      dimsA = len(a)
      dimsB = len(b)
@@ -110,19 +113,19 @@ def patched_infer_size(a, b):
          # were not the case, we'd need to write this using torch.sym_or() or
          # something like that).
          try:
-             b1 = guard_size_oblivious(sizeA == 1)
+             b1 = guard_or_false(sizeA == 1)
          except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
              b1 = False
          try:
-             b2 = guard_size_oblivious(sizeB == 1)
+             b2 = guard_or_false(sizeB == 1)
          except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
              b2 = False
          try:
-             b3 = guard_size_oblivious(sizeA == sizeB)
+             b3 = guard_or_false(sizeA == sizeB)
          except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
              b3 = False
          if b1 or b2 or b3:
-             expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
+             expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
          else:
              # PATCHED: generic case, the dimension is known, no need to assert
              expandedSizes[i] = torch.sym_max(sizeA, sizeB)
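Illustration (not part of the package diff): `guard_or_false` answers a symbolic comparison with False whenever the result cannot be proven, instead of possibly raising `GuardOnDataDependentSymNode` the way `guard_size_oblivious` can on unbacked sizes, so the patched branch falls through to `torch.sym_max(sizeA, sizeB)`. A minimal sketch, assuming a recent PyTorch that exposes `guard_or_false` and `ShapeEnv.create_unbacked_symint`:

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

    env = ShapeEnv()
    u0 = env.create_unbacked_symint()   # a data-dependent size, e.g. obtained from .item()

    # The comparison cannot be decided, so guard_or_false returns False and records
    # no guard; guard_size_oblivious may raise GuardOnDataDependentSymNode here,
    # which is why the original code wraps it in try/except.
    print(guard_or_false(u0 == 1))      # False
    print(guard_or_false(u0 == u0))     # True: statically known, no guard needed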
@@ -134,7 +137,6 @@ def patched__broadcast_shapes(*_shapes):
      from functools import reduce
      from torch._prims_common import IntLike
      from torch.fx.experimental.symbolic_shapes import (
-         guard_size_oblivious,
          guard_or_false,
          is_nested_int,
      )
@@ -171,13 +173,15 @@ def patched__broadcast_shapes(*_shapes):
                  continue
              # PATCHED: two cases, if == for sure, no broadcast,
              # otherwise maybe broadcast with max(dimensions)
-             if guard_size_oblivious(common_shape[idx] == 1):
+             if guard_or_false(common_shape[idx] != 1):
+                 pass
+             elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
                  if shape[idx] < 0:
                      raise ValueError(
                          "Attempting to broadcast a dimension with negative length!"
                      )
                  common_shape[idx] = shape[idx]
-             elif guard_size_oblivious(shape[idx] != 1):
+             else:
                  common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])

      return common_shape
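For reference (not from the package): with concrete sizes the rule implemented here is ordinary broadcasting; the patch only changes the undecidable case, which now resolves to `torch.sym_max(common_shape[idx], shape[idx])` rather than a hard guard. Using the public API:

    import torch

    # Ordinary broadcasting with concrete sizes is unchanged by the patch.
    print(torch.broadcast_shapes((2, 1, 4), (3, 1)))   # torch.Size([2, 3, 4])
    print(torch.broadcast_shapes((5,), (1,)))          # torch.Size([5])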
@@ -357,6 +361,10 @@ class patched_ShapeEnv:
                  },
              )

+             for source in self.var_to_sources.get(a, []):
+                 if user_tb:
+                     self.specialization_stacks[source] = user_tb
+
              # PATCHED: removed lines
              # if config.print_specializations:
              # self.log.warning(
@@ -396,6 +404,284 @@ class patched_ShapeEnv:
              # stacklevel=0,
              # )

+     def _evaluate_expr(
+         self,
+         orig_expr: "sympy.Basic",  # noqa: F821
+         hint: Optional[Union[bool, int, float]] = None,
+         fx_node: Optional[torch.fx.Node] = None,
+         size_oblivious: bool = False,
+         fallback_value: Optional[bool] = None,
+         *,
+         forcing_spec: bool = False,
+     ) -> "sympy.Basic":  # noqa: F821
+         # TODO: split conjunctions and evaluate them separately
+         import sympy
+         from torch.fx.experimental import _config as config
+         from torch.fx.experimental.symbolic_shapes import (
+             SympyBoolean,
+             log,
+             SymT,
+             symbol_is_type,
+         )
+         from torch._guards import ShapeGuard
+
+         if isinstance(
+             orig_expr,
+             (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse),
+         ):
+             return orig_expr
+
+         # Don't track this one. (Because this cache is inside this function the
+         # cache only lasts for the invocation of this function call)
+         @functools.cache
+         def compute_concrete_val() -> sympy.Basic:
+             if hint is None:
+                 # This is only ever called for expressions WITHOUT unbacked
+                 # symbols
+                 r = self.size_hint(orig_expr)
+                 assert r is not None
+                 return r
+             else:
+                 return sympy.sympify(hint)
+
+         concrete_val: Optional[sympy.Basic]
+
+         # Check if:
+         # 1. 'translation_validation' is set
+         # 2. the corresponding 'fx_node' is not 'None'
+         # 3. the guard should not be suppressed
+         # 4. the guard doesn't contain backed symfloat symbols
+         # since z3 can't handle floats
+         # 5. fallback_value is none.
+         # If all of the above check, we create an FX node representing the
+         # actual expression to be guarded.
+         node = None
+         fresh = False
+         if (
+             self._translation_validation_enabled
+             and fx_node is not None
+             and not self._suppress_guards_tls()
+             and not size_oblivious
+             and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols)
+             and fallback_value is None
+         ):
+             # TODO: does this even worked with unbacked :think:
+             concrete_val = compute_concrete_val()
+             if concrete_val is sympy.true:
+                 node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
+             elif concrete_val is sympy.false:
+                 neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
+                 node, fresh = self._create_fx_call_function(torch._assert, (neg,))
+             else:
+                 eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
+                 node, fresh = self._create_fx_call_function(torch._assert, (eql,))
+
+             assert node is not None
+             # If this is a fresh node, we have to remember the event index that
+             # corresponds to this assertion node.
+             # Reason: so that, given an assertion node, we can replay the ShapeEnv
+             # events until the point where this assertion node was freshly created.
+             if fresh:
+                 self._add_fx_node_metadata(node)
+
+         # After creating the FX node corresponding to orig_expr, we must make sure that
+         # no error will be raised until the end of this function.
+         #
+         # Reason: the translation validation may become invalid otherwise.
+         #
+         # If an error is raised before the end of this function, we remove the FX node
+         # inserted, and re-raise the error.
+         guard = None
+
+         try:
+             if orig_expr.is_number:
+                 self.log.debug("eval %s [trivial]", orig_expr)
+                 if hint is not None:
+                     if isinstance(hint, bool):
+                         assert orig_expr == hint, f"{orig_expr} != {hint}"
+                     else:
+                         assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}"
+                 return orig_expr
+
+             expr = orig_expr
+
+             static_expr = self._maybe_evaluate_static(expr, size_oblivious=size_oblivious)
+             if static_expr is not None:
+                 self.log.debug(
+                     "eval %s == %s [statically known]",
+                     (f"size_oblivious({orig_expr})" if size_oblivious else size_oblivious),
+                     static_expr,
+                 )
+                 if not size_oblivious and config.backed_size_oblivious and hint is not None:
+                     # TODO: maybe reconcile this with use of counterfactual hints
+                     # in unbacked case
+                     assert static_expr == hint, f"{static_expr} != {hint}"
+                 return static_expr
+
+             transmute_into_runtime_assert = False
+
+             concrete_val = None
+             if not (expr.free_symbols <= self.var_to_val.keys()):
+                 # TODO: dedupe this with _maybe_evaluate_static
+                 # Attempt to eliminate the unbacked SymInt
+                 new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
+                 assert new_expr is not None
+                 if not (new_expr.free_symbols <= self.var_to_val.keys()):
+                     ok = False
+
+                     # fallback_value is set when guard_or_true or guard_or_false are used.
+                     if not ok and fallback_value is not None:
+                         self._log_suppressed_dde(orig_expr, fallback_value)
+                         return fallback_value
+
+                     # oblivious_var_to_val will be defined iff we have sizes
+                     # with DimDynamic.OBLIVIOUS_SIZE type.
+                     # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+                     if (
+                         self.oblivious_var_to_val
+                         and not (
+                             correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
+                         ).free_symbols
+                         and not (
+                             counterfactual_hint := orig_expr.xreplace(
+                                 {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                             )
+                         ).free_symbols
+                         and correct_hint == counterfactual_hint
+                     ):
+                         # TODO: better logging
+                         log.info(
+                             "oblivious_size %s -> %s (passed counterfactual)",
+                             orig_expr,
+                             # pyrefly: ignore # unbound-name
+                             correct_hint,
+                         )
+                         # pyrefly: ignore # unbound-name
+                         concrete_val = correct_hint
+                         # NB: do NOT transmute into runtime assert
+                         ok = True
+
+                     # unbacked_var_to_val is not None iff propagate_real_tensors is on.
+                     # if propagate_real_tensors is on, we check the example values
+                     # to generate (unsound_result)
+                     # and if they pass we add a runtime assertions and continue.
+                     if (
+                         not ok
+                         and self.unbacked_var_to_val
+                         and not (
+                             unsound_result := orig_expr.xreplace(
+                                 self.unbacked_var_to_val
+                             ).xreplace(self.var_to_val)
+                         ).free_symbols
+                     ):
+                         # pyrefly: ignore # unbound-name
+                         self._log_real_tensor_propagation(orig_expr, unsound_result)
+                         transmute_into_runtime_assert = True
+                         # pyrefly: ignore # unbound-name
+                         concrete_val = unsound_result
+                         ok = True
+
+                     # Check if this is coming from a python assert statement,
+                     # if so, convert it to a runtime assertion
+                     # instead of failing.
+                     if not ok and self.trace_asserts and self._is_python_assert():
+                         concrete_val = sympy.true
+                         transmute_into_runtime_assert = True
+                         ok = True
+
+                     # PATCHED: ok -> True
+                     ok = True
+                     # if not ok:
+                     # raise self._make_data_dependent_error(
+                     # expr.xreplace(self.var_to_val),
+                     # expr,
+                     # expr_sym_node_id=self._expr_sym_node_id,
+                     # )
+                 else:
+                     expr = new_expr
+
+             if concrete_val is None:
+                 concrete_val = compute_concrete_val()
+             self._check_frozen(expr, concrete_val)
+
+             if (
+                 config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
+                 and isinstance(hint, bool)
+                 and isinstance(expr, (sympy.Eq, sympy.Ne))
+             ):
+                 expr = sympy.Not(expr)
+
+             # Turn this into a boolean expression, no longer need to consult
+             # concrete_val
+             if concrete_val is sympy.true:
+                 g = cast(SympyBoolean, expr)
+             elif concrete_val is sympy.false:
+                 g = sympy.Not(expr)
+             else:
+                 g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
+
+             if transmute_into_runtime_assert:
+                 self.guard_or_defer_runtime_assert(
+                     g, f"propagate_real_tensors: {orig_expr} == {concrete_val}"
+                 )
+                 return concrete_val
+
+             if not self._suppress_guards_tls():
+                 self._log_guard("eval", g, forcing_spec=forcing_spec)
+
+                 # TODO: If we successfully eliminate a symbol via equality, it
+                 # is not actually necessary to save a guard for the equality,
+                 # as we will implicitly generate a guard when we match that
+                 # input against the symbol. Probably the easiest way to
+                 # implement this is to have maybe_guard_rel return a bool
+                 # saying if it "subsumed" the guard (and therefore the guard
+                 # is no longer necessary)
+                 self._maybe_guard_rel(g)
+
+                 if (
+                     torch.compiler.is_exporting()
+                     and self.prefer_deferred_runtime_asserts_over_guards
+                 ):
+                     # it's fine to defer simple guards here without checking,
+                     # the _maybe_guard_rel() call above will set replacements if possible,
+                     # and so the result here will be statically known
+                     self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+                 else:
+                     # at this point, we've evaluated the concrete expr value, and have
+                     # flipped/negated the guard if necessary. Now we know what to guard
+                     # or defer to runtime assert on.
+                     guard = ShapeGuard(g, self._get_sloc(), size_oblivious=size_oblivious)
+                     self.guards.append(guard)
+                     self.axioms.update(dict(self.get_implications(self.simplify(g))))
+             else:
+                 self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
+
+         except Exception:
+             if fresh:
+                 self._remove_fx_node(node)
+             raise
+
+         if not self._suppress_guards_tls():
+             if guard is not None:  # we might have deferred this to runtime assert
+                 for s in g.free_symbols:
+                     self.symbol_guard_counter[s] += 1
+                     # Forcing_spec to avoid infinite recursion
+                     if (
+                         not forcing_spec
+                         and config.symbol_guard_limit_before_specialize is not None
+                         and self.symbol_guard_counter[s]
+                         > config.symbol_guard_limit_before_specialize
+                     ):
+                         # Force specialization
+                         self.log.info(
+                             "symbol_guard_limit_before_specialize=%s exceeded on %s",
+                             config.symbol_guard_limit_before_specialize,
+                             s,
+                         )
+                         self.evaluate_expr(s, forcing_spec=True)
+
+         return concrete_val
+

  def patched_vmap(func, in_dims=0, out_dims=0):
      """
@@ -570,3 +856,153 @@ def patched__constrain_user_specified_dimhint_range(
          return msg

      return None
+
+
+ def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+     """Patches ``torch._refs._maybe_broadcast``."""
+     from torch._prims_common import ShapeType, TensorLike, Number
+
+     # Computes common shape
+     common_shape = patched__broadcast_shapes(
+         *(t.shape if isinstance(t, TensorLike) else None for t in args)
+     )
+
+     def should_expand(a: ShapeType, b: ShapeType) -> bool:
+         from torch.fx.experimental.symbolic_shapes import (
+             guard_or_false,
+             sym_and,
+             sym_or,
+         )
+
+         if len(a) != len(b):
+             return True
+
+         for x, y in zip(a, b):
+             if guard_or_false(x != y):
+                 # We know they are not the same.
+                 return True
+
+             # They are the same or we do not know if they are the same or not.
+             # 1==1 no-broadcast
+             # u0==1 and 1==u0 cases. We broadcast!
+             if guard_or_false(sym_and(x == 1, y == 1)):
+                 pass
+             elif guard_or_false(sym_or(x == 1, y == 1)):
+                 # assume broadcasting.
+                 return True
+
+             # u0==u1 assume the same, no broadcasting!
+             # PATCHED: avoid errors
+             return True  # guard_or_true(x != y)
+             # torch._check(
+             # x == y,
+             # lambda x=x, y=y: (
+             # f"sizes assumed to be the same due to unbacked "
+             # f"broadcasting semantics x={x!r}, y={y!r}"
+             # ),
+             # )
+
+         return False
+
+     def __maybe_broadcast(x, shape):
+         if x is None:
+             return None
+         elif isinstance(x, Number):
+             return x
+         elif isinstance(x, TensorLike):
+             if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
+                 return x
+
+             if should_expand(x.shape, common_shape):
+                 return x.expand(common_shape)
+
+             return x
+         else:
+             raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
+
+     return tuple(__maybe_broadcast(x, common_shape) for x in args)
+
+
+ def patched__broadcast_in_dim_meta(
+     a: torch._prims_common.TensorLikeType,
+     shape: torch._prims_common.ShapeType,
+     broadcast_dimensions: Sequence[int],
+ ):
+     """Patches ``torch._prims._broadcast_in_dim_meta``."""
+     from torch.fx.experimental.symbolic_shapes import (
+         guard_or_false,
+         guard_or_true,
+         sym_or,
+     )
+
+     # Type checks
+     assert isinstance(a, torch._prims_common.TensorLike)
+     assert isinstance(shape, Sequence)
+     assert isinstance(broadcast_dimensions, Sequence)
+
+     # every dimension must be accounted for
+     assert a.ndim == len(broadcast_dimensions)
+
+     # broadcast shape must have weakly more dimensions
+     assert len(shape) >= a.ndim
+
+     # broadcast_dimensions must be an ascending sequence
+     # (no relative reordering of dims) of integers and
+     # each dimension must be within the new shape
+     def _greater_than_reduce(acc, x):
+         assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+         assert x > acc
+         assert x < len(shape)
+
+         return x
+
+     reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+     # shape must be broadcastable to
+     for idx, new_idx in enumerate(broadcast_dimensions):
+         torch._check(
+             sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+             lambda idx=idx, new_idx=new_idx: (
+                 f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+             ),
+         )
+
+     new_strides = []
+     original_idx = 0
+     for idx in range(len(shape)):
+         if idx in broadcast_dimensions:
+             # Assigns a stride of zero to dimensions
+             # which were actually broadcast
+             if guard_or_false(a.shape[original_idx] == 1):
+                 if guard_or_false(a.shape[original_idx] == shape[idx]):
+                     new_strides.append(a.stride()[original_idx])
+                 else:
+                     new_strides.append(0)
+             # PATCHED: disabled this check
+             elif guard_or_false(a.shape[original_idx] != 1):
+                 new_strides.append(a.stride()[original_idx])
+             else:
+                 torch._check(
+                     a.shape[original_idx] == shape[idx],
+                     lambda idx=idx, original_idx=original_idx: (
+                         f"non-broadcasting semantics require "
+                         f"{a.shape[original_idx]} == {shape[idx]}, "
+                         f"{guard_or_false(a.shape[idx] != 1)}, "
+                         f"guard_or_false(a.shape[idx] == 1)="
+                         f"{guard_or_false(a.shape[idx] == 1)}, "
+                         f"a.stride()={a.stride()}, idx={idx}, "
+                         f"original_idx={original_idx}"
+                     ),
+                 )
+                 new_strides.append(a.stride()[original_idx])
+             original_idx = original_idx + 1
+         else:
+             if guard_or_true(shape[idx] != 1):
+                 # consistent with previous use of guard_size_oblivious
+                 new_strides.append(0)
+             elif original_idx == a.ndim:
+                 new_strides.append(1)
+             else:
+                 new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+     return a.as_strided(shape, new_strides, a.storage_offset())
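As a concrete reference for the stride handling above (not part of the package diff): a dimension that is actually broadcast is materialised as a zero-stride view that shares the original storage, which is what the `new_strides.append(0)` branches reproduce for the meta implementation:

    import torch

    a = torch.arange(3).reshape(3, 1)    # contiguous, strides (1, 1)
    b = a.expand(3, 4)                   # broadcast the size-1 dimension, no copy
    print(b.stride())                    # (1, 0): the broadcast dim gets stride 0
    print(b.data_ptr() == a.data_ptr())  # True: same underlying storage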