onnx-diagnostic 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +213 -5
  3. onnx_diagnostic/export/dynamic_shapes.py +48 -20
  4. onnx_diagnostic/export/shape_helper.py +126 -0
  5. onnx_diagnostic/ext_test_case.py +31 -0
  6. onnx_diagnostic/helpers/cache_helper.py +42 -20
  7. onnx_diagnostic/helpers/config_helper.py +16 -1
  8. onnx_diagnostic/helpers/log_helper.py +1561 -177
  9. onnx_diagnostic/helpers/torch_helper.py +6 -2
  10. onnx_diagnostic/tasks/__init__.py +2 -0
  11. onnx_diagnostic/tasks/image_text_to_text.py +69 -18
  12. onnx_diagnostic/tasks/text_generation.py +17 -8
  13. onnx_diagnostic/tasks/text_to_image.py +91 -0
  14. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  15. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
  17. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  18. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  19. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  20. onnx_diagnostic/torch_models/hghub/hub_api.py +73 -5
  21. onnx_diagnostic/torch_models/hghub/hub_data.py +7 -2
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +74 -14
  24. onnx_diagnostic/torch_models/validate.py +45 -16
  25. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +29 -24
  27. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.7.0"
6
+ __version__ = "0.7.2"
7
7
  __author__ = "Xavier Dupré"
@@ -333,7 +333,24 @@ def get_parser_validate() -> ArgumentParser:
333
333
  of supported tasks.
334
334
  """
335
335
  ),
336
- epilog="If the model id is specified, one untrained version of it is instantiated.",
336
+ epilog=textwrap.dedent(
337
+ """
338
+ If the model id is specified, one untrained version of it is instantiated.
339
+ Examples:
340
+
341
+ python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
342
+ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
343
+ --dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
344
+
345
+ python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
346
+ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
347
+ --dtype float16 --device cuda --patch --export custom --opt default
348
+
349
+ python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
350
+ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
351
+ --dtype float16 --device cuda --export modelbuilder
352
+ """
353
+ ),
337
354
  formatter_class=RawTextHelpFormatter,
338
355
  )
339
356
  parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
@@ -372,6 +389,12 @@ def get_parser_validate() -> ArgumentParser:
372
389
  type=int,
373
390
  help="Raises an exception if a dynamic dimension becomes static.",
374
391
  )
392
+ parser.add_argument(
393
+ "--same-as-trained",
394
+ default=False,
395
+ action=BooleanOptionalAction,
396
+ help="Validates a model identical to the trained model but not trained.",
397
+ )
375
398
  parser.add_argument(
376
399
  "--trained",
377
400
  default=False,
@@ -487,7 +510,8 @@ def _cmd_validate(argv: List[Any]):
487
510
  do_run=args.run,
488
511
  verbose=args.verbose,
489
512
  quiet=args.quiet,
490
- trained=args.trained,
513
+ same_as_pretrained=args.same_as_trained,
514
+ use_pretrained=args.trained,
491
515
  dtype=args.dtype,
492
516
  device=args.device,
493
517
  patch=args.patch,
@@ -609,6 +633,178 @@ def _cmd_stats(argv: List[Any]):
609
633
  print("done.")
610
634
 
611
635
 
636
+ def get_parser_agg() -> ArgumentParser:
637
+ parser = ArgumentParser(
638
+ prog="agg",
639
+ description=textwrap.dedent(
640
+ """
641
+ Aggregates statistics coming from benchmarks.
642
+ Every run is a row. Every row is indexed by some keys,
643
+ and produces values. Every row has a date.
644
+ """
645
+ ),
646
+ epilog=textwrap.dedent(
647
+ """
648
+ examples:\n
649
+
650
+ python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
651
+ """
652
+ ),
653
+ formatter_class=RawTextHelpFormatter,
654
+ )
655
+ parser.add_argument("output", help="output excel file")
656
+ parser.add_argument(
657
+ "inputs",
658
+ nargs="+",
659
+ help="input csv or zip files, at least 1, it can be a name, or search path",
660
+ )
661
+ parser.add_argument(
662
+ "--filter", default="rawdata_.*.csv", help="filter for input files inside zip files"
663
+ )
664
+ parser.add_argument(
665
+ "--recent",
666
+ default=True,
667
+ action=BooleanOptionalAction,
668
+ help="Keeps only the most recent experiment for the same of keys.",
669
+ )
670
+ parser.add_argument(
671
+ "--keep-last-date",
672
+ default=False,
673
+ action=BooleanOptionalAction,
674
+ help="Rewrite all dates to the last one to simplifies the analysis, "
675
+ "this assume changing the date does not add ambiguity, if any, option "
676
+ "--recent should be added.",
677
+ )
678
+ parser.add_argument(
679
+ "--raw",
680
+ default=True,
681
+ action=BooleanOptionalAction,
682
+ help="Keeps the raw data in a sheet.",
683
+ )
684
+ parser.add_argument("-t", "--time", default="DATE", help="Date or time column")
685
+ parser.add_argument(
686
+ "-k",
687
+ "--keys",
688
+ default="^version_.*,^model_.*,device,opt_patterns,suite,memory_peak,"
689
+ "machine,exporter,dynamic,rtopt,dtype,device,architecture",
690
+ help="List of columns to consider as keys, "
691
+ "multiple values are separated by `,`\n"
692
+ "regular expressions are allowed",
693
+ )
694
+ parser.add_argument(
695
+ "--drop-keys",
696
+ default="",
697
+ help="Drops keys from the given list. Something it is faster "
698
+ "to remove one than to select all the remaining ones.",
699
+ )
700
+ parser.add_argument(
701
+ "-w",
702
+ "--values",
703
+ default="^time_.*,^disc.*,^ERR_.*,CMD,^ITER.*,^onnx_.*,^op_onnx_.*,^peak_gpu_.*",
704
+ help="List of columns to consider as values, "
705
+ "multiple values are separated by `,`\n"
706
+ "regular expressions are allowed",
707
+ )
708
+ parser.add_argument(
709
+ "-i", "--ignored", default="^version_.*", help="List of columns to ignore"
710
+ )
711
+ parser.add_argument(
712
+ "-f",
713
+ "--formula",
714
+ default="speedup,bucket[speedup],ERR1,n_models,n_model_eager,"
715
+ "n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
716
+ "n_model_pass,n_model_faster,"
717
+ "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
718
+ "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
719
+ "n_node_constant,n_node_shape,n_node_expand,"
720
+ "n_node_function,n_node_initializer,n_node_scatter,"
721
+ "time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
722
+ help="Columns to compute after the aggregation was done.",
723
+ )
724
+ parser.add_argument(
725
+ "--views",
726
+ default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
727
+ "bucket-speedup,raw-short,counts,peak-gpu,onnx",
728
+ help="Views to add to the output files.",
729
+ )
730
+ parser.add_argument(
731
+ "--csv",
732
+ default="raw-short",
733
+ help="Views to dump as csv files.",
734
+ )
735
+ parser.add_argument("-v", "--verbose", type=int, default=0, help="verbosity")
736
+ parser.add_argument(
737
+ "--filter-in",
738
+ default="",
739
+ help="adds a filter to filter in data, syntax is\n"
740
+ '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
741
+ )
742
+ parser.add_argument(
743
+ "--filter-out",
744
+ default="",
745
+ help="adds a filter to filter out data, syntax is\n"
746
+ '``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
747
+ )
748
+ return parser
749
+
750
+
751
+ def _cmd_agg(argv: List[Any]):
752
+ from .helpers.log_helper import (
753
+ CubeLogsPerformance,
754
+ open_dataframe,
755
+ enumerate_csv_files,
756
+ filter_data,
757
+ )
758
+
759
+ parser = get_parser_agg()
760
+ args = parser.parse_args(argv[1:])
761
+ reg = re.compile(args.filter)
762
+
763
+ csv = list(
764
+ enumerate_csv_files(
765
+ args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
766
+ )
767
+ )
768
+ assert csv, f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv}"
769
+ if args.verbose:
770
+ from tqdm import tqdm
771
+
772
+ loop = tqdm(csv)
773
+ else:
774
+ loop = csv
775
+ dfs = []
776
+ for c in loop:
777
+ df = open_dataframe(c)
778
+ assert (
779
+ args.time in df.columns
780
+ ), f"Missing time column {args.time!r} in {c!r}\n{df.head()}\n{sorted(df.columns)}"
781
+ dfs.append(filter_data(df, filter_in=args.filter_in, filter_out=args.filter_out))
782
+
783
+ drop_keys = set(args.drop_keys.split(","))
784
+ cube = CubeLogsPerformance(
785
+ dfs,
786
+ time=args.time,
787
+ keys=[a for a in args.keys.split(",") if a and a not in drop_keys],
788
+ values=[a for a in args.values.split(",") if a],
789
+ ignored=[a for a in args.ignored.split(",") if a],
790
+ recent=args.recent,
791
+ formulas={k: k for k in args.formula.split(",")},
792
+ keep_last_date=args.keep_last_date,
793
+ )
794
+ cube.load(verbose=max(args.verbose - 1, 0))
795
+ if args.verbose:
796
+ print(f"Dumps final file into {args.output!r}")
797
+ cube.to_excel(
798
+ args.output,
799
+ {k: k for k in args.views.split(",")},
800
+ verbose=args.verbose,
801
+ csv=args.csv.split(","),
802
+ raw=args.raw,
803
+ )
804
+ if args.verbose:
805
+ print(f"Wrote {args.output!r}")
806
+
807
+
612
808
  def get_main_parser() -> ArgumentParser:
613
809
  parser = ArgumentParser(
614
810
  prog="onnx_diagnostic",
@@ -619,19 +815,29 @@ def get_main_parser() -> ArgumentParser:
619
815
  Type 'python -m onnx_diagnostic <cmd> --help'
620
816
  to get help for a specific command.
621
817
 
818
+ agg - aggregates statistics from multiple files
622
819
  config - prints a configuration for a model id
623
820
  find - find node consuming or producing a result
624
821
  lighten - makes an onnx model lighter by removing the weights,
625
- unlighten - restores an onnx model produces by the previous experiment
626
822
  print - prints the model on standard output
627
- validate - validate a model
628
823
  stats - produces statistics on a model
824
+ unlighten - restores an onnx model produces by the previous experiment
825
+ validate - validate a model
629
826
  """
630
827
  ),
631
828
  )
632
829
  parser.add_argument(
633
830
  "cmd",
634
- choices=["config", "find", "lighten", "print", "stats", "unlighten", "validate"],
831
+ choices=[
832
+ "agg",
833
+ "config",
834
+ "find",
835
+ "lighten",
836
+ "print",
837
+ "stats",
838
+ "unlighten",
839
+ "validate",
840
+ ],
635
841
  help="Selects a command.",
636
842
  )
637
843
  return parser
@@ -646,6 +852,7 @@ def main(argv: Optional[List[Any]] = None):
646
852
  config=_cmd_config,
647
853
  validate=_cmd_validate,
648
854
  stats=_cmd_stats,
855
+ agg=_cmd_agg,
649
856
  )
650
857
 
651
858
  if argv is None:
@@ -667,6 +874,7 @@ def main(argv: Optional[List[Any]] = None):
667
874
  config=get_parser_config,
668
875
  validate=get_parser_validate,
669
876
  stats=get_parser_stats,
877
+ agg=get_parser_agg,
670
878
  )
671
879
  cmd = argv[0]
672
880
  if cmd not in parsers:
@@ -630,9 +630,12 @@ class ModelInputs:
630
630
  method_name: str = "forward",
631
631
  name: str = "main",
632
632
  ):
633
- assert isinstance(model, torch.nn.Module) or inspect.ismodule(
634
- model
635
- ), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
633
+ assert (
634
+ model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
635
+ ), (
636
+ f"unexpected type for model={type(model)}, "
637
+ f"it must be a torch.nn.Module or None"
638
+ )
636
639
  assert name, (
637
640
  f"name={name!r} cannot be empty this string is used to "
638
641
  f"display meaningful error messages"
@@ -641,26 +644,42 @@ class ModelInputs:
641
644
  self.model = model
642
645
  self.level = level
643
646
  self.method_name = method_name
644
- self.forward = getattr(model, method_name)
645
- self.signature = inspect.signature(self.forward)
647
+ self.forward = getattr(model, method_name) if model is not None else None
648
+ self.signature = inspect.signature(self.forward) if self.forward else None
646
649
 
647
650
  # information about the signature
648
- self.forward_parameter_names = set(
649
- p.name
650
- for p in self.signature.parameters.values()
651
- if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
651
+ self.forward_parameter_names = (
652
+ set(
653
+ p.name
654
+ for p in self.signature.parameters.values()
655
+ if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
656
+ )
657
+ if self.signature
658
+ else None
659
+ )
660
+ self.forward_ordered_parameter_names = (
661
+ list(self.signature.parameters) if self.signature else None
662
+ )
663
+ self.forward_positioned_parameter_names = (
664
+ [
665
+ p.name
666
+ for p in self.signature.parameters.values()
667
+ if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
668
+ ]
669
+ if self.signature
670
+ else None
671
+ )
672
+ names = (
673
+ [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
674
+ if self.signature
675
+ else None
652
676
  )
653
- self.forward_ordered_parameter_names = list(self.signature.parameters)
654
- self.forward_positioned_parameter_names = [
655
- p.name
656
- for p in self.signature.parameters.values()
657
- if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
658
- ]
659
- names = [
660
- p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
661
- ]
662
677
  self.forward_args = names[0] if names else None
663
- names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
678
+ names = (
679
+ [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
680
+ if self.signature
681
+ else None
682
+ )
664
683
  self.forward_kwargs = names[0] if names else None
665
684
  self.forward_custom_op_schema = None
666
685
  self.forward_need_serialization = False
@@ -711,6 +730,7 @@ class ModelInputs:
711
730
  @property
712
731
  def true_model_name(self) -> str:
713
732
  "Returns class name or module name."
733
+ assert self.model is not None, "model was None when the class was initialized."
714
734
  return (
715
735
  self.model.__class__.__name__
716
736
  if isinstance(self.model, torch.nn.Module)
@@ -942,7 +962,7 @@ class ModelInputs:
942
962
  )
943
963
  )
944
964
  names = s2.pop()
945
- for name in names:
965
+ for i, name in enumerate(names):
946
966
  assert name not in {"_diag", "verbose"}, (
947
967
  f"{self.full_name}: unexpected parameter {name!r}, names={names}"
948
968
  f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
@@ -968,6 +988,14 @@ class ModelInputs:
968
988
  with the corresponding dynamic shapes.
969
989
  *kwargs*, *dynamic_shapes* are modified inplace.
970
990
  """
991
+ assert (
992
+ self.signature is not None
993
+ and self.forward_parameter_names is not None
994
+ and self.forward_ordered_parameter_names is not None
995
+ ), (
996
+ "model was None when the class was initialized, "
997
+ "cannot move args to kwargs without the signature."
998
+ )
971
999
  sig = self.signature
972
1000
  arg_dyn, kw_dyn = dynamic_shapes
973
1001
  for i, p in enumerate(sig.parameters):
@@ -0,0 +1,126 @@
1
+ from typing import Any, Dict, List, Set, Tuple, Union
2
+ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
3
+ from .dynamic_shapes import ModelInputs
4
+
5
+
6
+ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
7
+ """
8
+ Returns the dynamic shapes for the given inputs.
9
+ All dimensions are considered as dynamic.
10
+ ``dim_prefix`` can be a string (the function uses it as a prefix),
11
+ or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
12
+
13
+ .. runpython::
14
+ :showcode:
15
+
16
+ import pprint
17
+ import torch
18
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
19
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
20
+
21
+ bsize, nheads, slen, dim = 2, 1, 30, 96
22
+ inputs = dict(
23
+ input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
24
+ attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
25
+ position_ids=torch.arange(3, dtype=torch.int64),
26
+ past_key_values=make_dynamic_cache(
27
+ [(torch.randn(bsize, nheads, slen, dim),
28
+ torch.randn(bsize, nheads, slen, dim))]
29
+ ),
30
+ )
31
+ ds = all_dynamic_shape_from_inputs(inputs)
32
+ pprint.pprint(ds)
33
+ """
34
+ if isinstance(dim_prefix, str):
35
+ prefixes: Set[str] = set()
36
+
37
+ def tensor_to_shape(tensor):
38
+ n = len(prefixes)
39
+ p = f"{dim_prefix}_{n}"
40
+ prefixes.add(p)
41
+ return {i: f"{p}_{i}" for i in range(tensor.ndim)}
42
+
43
+ else:
44
+
45
+ def tensor_to_shape(tensor):
46
+ return {i: dim_prefix for i in range(tensor.ndim)} # noqa: C420
47
+
48
+ return flatten_unflatten_for_dynamic_shapes(
49
+ inputs, change_function=tensor_to_shape, use_dict=True
50
+ )
51
+
52
+
53
+ def guess_dynamic_shapes_from_inputs(
54
+ inputs: List[Any], auto: Union[bool, str] = False
55
+ ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
56
+ """
57
+ Guesses which dimension is dimension from a set of inputs.
58
+ Every dimension having different values over multiple sets
59
+ of inputs. Every dimension not changing remains static.
60
+
61
+ :param inputs: a list of input sets
62
+ :param auto: True for ``torch.export.Dim.AUTO``,
63
+ False for ``torch.export.Dim.DYNAMIC``,
64
+ a string to get a unique string for every dynamic dimension
65
+ :return: args and kwargs
66
+
67
+ .. runpython::
68
+ :showcode:
69
+
70
+ import pprint
71
+ import torch
72
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
73
+ from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
74
+
75
+ bsize, nheads, slen, dim = 2, 1, 30, 96
76
+ inputs1 = dict(
77
+ input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
78
+ attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
79
+ position_ids=torch.arange(3, dtype=torch.int64),
80
+ past_key_values=make_dynamic_cache(
81
+ [
82
+ (
83
+ torch.randn(bsize, nheads, slen, dim),
84
+ torch.randn(bsize, nheads, slen, dim),
85
+ ),
86
+ ]
87
+ ),
88
+ )
89
+ bsize, nheads, slen, dim = 3, 1, 33, 96
90
+ inputs2 = dict(
91
+ input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
92
+ attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
93
+ position_ids=torch.arange(4, dtype=torch.int64),
94
+ past_key_values=make_dynamic_cache(
95
+ [
96
+ (
97
+ torch.randn(bsize, nheads, slen, dim),
98
+ torch.randn(bsize, nheads, slen, dim),
99
+ ),
100
+ ]
101
+ ),
102
+ )
103
+ ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
104
+ pprint.pprint(ds)
105
+
106
+ This function returns something equivalent to function
107
+ :class:`torch.export.dynamic_shapes.AdditionalInputs` but this
108
+ one needs a model.
109
+
110
+ .. runpython::
111
+ :showcode:
112
+
113
+ import pprint
114
+ import torch
115
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
116
+ from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
117
+ from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
118
+
119
+ data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
120
+ ds = torch.export.dynamic_shapes.AdditionalInputs()
121
+ ds.add((), data["inputs"])
122
+ ds.add((), data["inputs2"])
123
+ pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
124
+ """
125
+ mi = ModelInputs(None, inputs)
126
+ return mi.guess_dynamic_shapes(auto=auto)
@@ -756,6 +756,18 @@ class ExtTestCase(unittest.TestCase):
756
756
  "Adds a todo printed when all test are run."
757
757
  cls._todos.append((f, msg))
758
758
 
759
+ @classmethod
760
+ def ort(cls):
761
+ import onnxruntime
762
+
763
+ return onnxruntime
764
+
765
+ @classmethod
766
+ def to_onnx(self, *args, **kwargs):
767
+ from experimental_experiment.torch_interpreter import to_onnx
768
+
769
+ return to_onnx(*args, **kwargs)
770
+
759
771
  def print_model(self, model: "ModelProto"): # noqa: F821
760
772
  "Prints a ModelProto"
761
773
  from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
@@ -917,6 +929,15 @@ class ExtTestCase(unittest.TestCase):
917
929
  ]
918
930
  raise AssertionError("\n".join(rows)) # noqa: B904
919
931
 
932
+ def assertEqualDataFrame(self, d1, d2, **kwargs):
933
+ """
934
+ Checks that two dataframes are equal.
935
+ Calls :func:`pandas.testing.assert_frame_equal`.
936
+ """
937
+ from pandas.testing import assert_frame_equal
938
+
939
+ assert_frame_equal(d1, d2, **kwargs)
940
+
920
941
  def assertEqualTrue(self, value: Any, msg: str = ""):
921
942
  if value is True:
922
943
  return
@@ -967,6 +988,16 @@ class ExtTestCase(unittest.TestCase):
967
988
  atol=atol,
968
989
  rtol=rtol,
969
990
  )
991
+ elif expected.__class__.__name__ == "StaticCache":
992
+ self.assertEqual(type(expected), type(value), msg=msg)
993
+ self.assertEqual(expected.max_cache_len, value.max_cache_len)
994
+ atts = ["key_cache", "value_cache"]
995
+ self.assertEqualAny(
996
+ {k: expected.__dict__.get(k, None) for k in atts},
997
+ {k: value.__dict__.get(k, None) for k in atts},
998
+ atol=atol,
999
+ rtol=rtol,
1000
+ )
970
1001
  elif expected.__class__.__name__ == "EncoderDecoderCache":
971
1002
  self.assertEqual(type(expected), type(value), msg=msg)
972
1003
  atts = ["self_attention_cache", "cross_attention_cache"]
@@ -1,11 +1,15 @@
1
- from typing import Any, List, Tuple
1
+ from typing import Any, Callable, List, Optional, Tuple
2
2
  import packaging.version as pv
3
3
  import torch
4
4
  import transformers
5
5
  import transformers.cache_utils
6
6
 
7
7
 
8
- def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
8
+ def flatten_unflatten_for_dynamic_shapes(
9
+ obj: Any,
10
+ use_dict: bool = False,
11
+ change_function: Optional[Callable[[torch.Tensor], Any]] = None,
12
+ ) -> Any:
9
13
  """
10
14
  Returns the object in a different structure similar to what
11
15
  the definition of the dynamic shapes should use.
@@ -15,11 +19,13 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
15
19
  :func:`torch.export.export` only considers the values,
16
20
  the context gives the dictionary keys but it is not expressed
17
21
  in the dynamic shapes, these specifications seems to be different
18
- for the strict and non strict mode.
22
+ for the strict and non strict mode. It also preserves tuple.
23
+ :param change_function: to modifies the tensor in the structure itself,
24
+ like replace them by a shape
19
25
  :return: the serialized object
20
26
  """
21
27
  if isinstance(obj, torch.Tensor):
22
- return obj
28
+ return change_function(obj) if change_function else obj
23
29
  flat, spec = torch.utils._pytree.tree_flatten(obj)
24
30
  start = 0
25
31
  end = 0
@@ -27,12 +33,17 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
27
33
  for subspec in spec.children_specs:
28
34
  end += subspec.num_leaves
29
35
  value = subspec.unflatten(flat[start:end])
30
- value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
36
+ value = flatten_unflatten_for_dynamic_shapes(
37
+ value, use_dict=use_dict, change_function=change_function
38
+ )
31
39
  subtrees.append(value)
32
40
  start = end
33
- if use_dict and (spec.type is dict or spec.context):
34
- # This a dictionary.
35
- return dict(zip(spec.context, subtrees))
41
+ if use_dict:
42
+ if spec.type is dict or spec.context:
43
+ # This a dictionary.
44
+ return dict(zip(spec.context, subtrees))
45
+ if spec.type is tuple:
46
+ return tuple(subtrees)
36
47
  # This is a list.
37
48
  return subtrees
38
49
 
@@ -143,10 +154,12 @@ else:
143
154
 
144
155
  def make_static_cache(
145
156
  key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
157
+ max_cache_len: Optional[int] = None,
146
158
  ) -> transformers.cache_utils.DynamicCache:
147
159
  """
148
160
  Creates an instance of :class:`transformers.cache_utils.StaticCache`.
149
161
  :param key_value_pairs: list of pairs of (key, values)
162
+ :param max_cache_len: max_cache_length or something inferred from the vector
150
163
  :return: :class:`transformers.cache_utils.StaticCache`
151
164
 
152
165
  Example:
@@ -168,7 +181,8 @@ def make_static_cache(
168
181
  torch.randn(bsize, nheads, slen, dim),
169
182
  )
170
183
  for i in range(n_layers)
171
- ]
184
+ ],
185
+ max_cache_len=10,
172
186
  )
173
187
  print(string_type(past_key_values, with_shape=True))
174
188
  """
@@ -179,24 +193,32 @@ def make_static_cache(
179
193
  self.num_attention_heads = key_value_pairs[0][0].shape[1]
180
194
  self.num_hidden_layers = len(key_value_pairs)
181
195
 
196
+ assert max_cache_len is not None, (
197
+ f"max_cache_len={max_cache_len} cannot be setup "
198
+ f"automatically yet from shape {key_value_pairs[0][0].shape}"
199
+ )
200
+ torch._check(
201
+ max_cache_len >= key_value_pairs[0][0].shape[2],
202
+ (
203
+ f"max_cache_len={max_cache_len} cannot be smaller "
204
+ f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
205
+ f"{key_value_pairs[0][0].shape}"
206
+ ),
207
+ )
182
208
  cache = transformers.cache_utils.StaticCache(
183
209
  _config(),
184
210
  max_batch_size=key_value_pairs[0][0].shape[0],
185
211
  device=key_value_pairs[0][0].device,
186
212
  dtype=key_value_pairs[0][0].dtype,
187
- max_cache_len=key_value_pairs[0][0].shape[2],
213
+ max_cache_len=max_cache_len,
188
214
  )
189
215
  for i in range(len(key_value_pairs)):
190
- assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
191
- f"Shape mismatch, expected {cache.key_cache[i].shape}, "
192
- f"got {key_value_pairs[i][0].shape}"
193
- )
194
- cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
195
- assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
196
- f"Shape mismatch, expected {cache.value_cache[i].shape}, "
197
- f"got {key_value_pairs[i][1].shape}"
198
- )
199
- cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
216
+ assert (
217
+ key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
218
+ ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
219
+ d = key_value_pairs[i][1].shape[2]
220
+ cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
221
+ cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
200
222
  return cache
201
223
 
202
224