onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +156 -47
  3. onnx_diagnostic/export/dynamic_shapes.py +6 -6
  4. onnx_diagnostic/export/shape_helper.py +124 -6
  5. onnx_diagnostic/ext_test_case.py +5 -1
  6. onnx_diagnostic/helpers/cache_helper.py +68 -42
  7. onnx_diagnostic/helpers/config_helper.py +2 -1
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
  9. onnx_diagnostic/helpers/helper.py +3 -0
  10. onnx_diagnostic/helpers/rt_helper.py +3 -3
  11. onnx_diagnostic/tasks/image_text_to_text.py +7 -6
  12. onnx_diagnostic/tasks/text_generation.py +7 -4
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
  15. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
  17. onnx_diagnostic/torch_models/code_sample.py +343 -0
  18. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
  19. onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
  20. onnx_diagnostic/torch_models/validate.py +73 -29
  21. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
  22. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
  23. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
  24. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
  25. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.7.14"
6
+ __version__ = "0.7.16"
7
7
  __author__ = "Xavier Dupré"
@@ -371,30 +371,34 @@ class _BoolOrParseDictPatch(argparse.Action):
371
371
  setattr(namespace, self.dest, d)
372
372
 
373
373
 
374
- def get_parser_validate() -> ArgumentParser:
374
+ def get_parser_validate(name: str = "validate") -> ArgumentParser:
375
375
  parser = ArgumentParser(
376
- prog="validate",
376
+ prog=name,
377
377
  description=textwrap.dedent(
378
378
  """
379
- Prints out dummy inputs for a particular task or a model id.
380
- If both mid and task are empty, the command line displays the list
381
- of supported tasks.
379
+ Validates a model for a particular task given the model id.
380
+ It exports the model and then validates it by computing the discrepancies
381
+ on different input sets.
382
+ """
383
+ if name == "validate"
384
+ else """
385
+ Creates a script to export a model for a particular task given the model id.
382
386
  """
383
387
  ),
384
388
  epilog=textwrap.dedent(
385
- """
389
+ f"""
386
390
  If the model id is specified, one untrained version of it is instantiated.
387
391
  Examples:
388
392
 
389
- python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
393
+ python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
390
394
  --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
391
395
  --dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
392
396
 
393
- python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
397
+ python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
394
398
  --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
395
399
  --dtype float16 --device cuda --patch --export custom --opt default
396
400
 
397
- python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
401
+ python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
398
402
  --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
399
403
  --dtype float16 --device cuda --export modelbuilder
400
404
 
@@ -405,12 +409,12 @@ def get_parser_validate() -> ArgumentParser:
405
409
  The behaviour may be modified compare the original configuration,
406
410
  the following argument can be rope_scaling to dynamic:
407
411
 
408
- --mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
412
+ --mop \"rope_scaling={{'rope_type': 'dynamic', 'factor': 10.0}}\""
409
413
 
410
414
  You can profile the command line by running:
411
415
 
412
- pyinstrument -m onnx_diagnostic validate ...
413
- pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
416
+ pyinstrument -m onnx_diagnostic {name} ...
417
+ pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ...
414
418
  """
415
419
  ),
416
420
  formatter_class=RawTextHelpFormatter,
@@ -460,19 +464,19 @@ def get_parser_validate() -> ArgumentParser:
460
464
  "--same-as-trained",
461
465
  default=False,
462
466
  action=BooleanOptionalAction,
463
- help="Validates a model identical to the trained model but not trained.",
467
+ help="Validates or exports a model identical to the trained model but not trained.",
464
468
  )
465
469
  parser.add_argument(
466
470
  "--trained",
467
471
  default=False,
468
472
  action=BooleanOptionalAction,
469
- help="Validates the trained model (requires downloading).",
473
+ help="Validates or exports the trained model (requires downloading).",
470
474
  )
471
475
  parser.add_argument(
472
476
  "--inputs2",
473
477
  default=1,
474
478
  type=int,
475
- help="Validates the model on a second set of inputs\n"
479
+ help="Validates or exports the model on a second set of inputs\n"
476
480
  "to check the exported model supports dynamism. The values is used "
477
481
  "as an increment to the first set of inputs. A high value may trick "
478
482
  "a different behavior in the model and missed by the exporter.",
@@ -504,13 +508,14 @@ def get_parser_validate() -> ArgumentParser:
504
508
  "--subfolder",
505
509
  help="Subfolder where to find the model and the configuration.",
506
510
  )
507
- parser.add_argument(
508
- "--ortfusiontype",
509
- required=False,
510
- help="Applies onnxruntime fusion, this parameter should contain the\n"
511
- "model type or multiple values separated by `|`. `ALL` can be used\n"
512
- "to run them all.",
513
- )
511
+ if name == "validate":
512
+ parser.add_argument(
513
+ "--ortfusiontype",
514
+ required=False,
515
+ help="Applies onnxruntime fusion, this parameter should contain the\n"
516
+ "model type or multiple values separated by `|`. `ALL` can be used\n"
517
+ "to run them all.",
518
+ )
514
519
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
515
520
  parser.add_argument("--dtype", help="Changes dtype if necessary.")
516
521
  parser.add_argument("--device", help="Changes the device if necessary.")
@@ -532,27 +537,38 @@ def get_parser_validate() -> ArgumentParser:
532
537
  "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
533
538
  action=_ParseDict,
534
539
  )
535
- parser.add_argument(
536
- "--repeat",
537
- default=1,
538
- type=int,
539
- help="number of times to run the model to measures inference time",
540
- )
541
- parser.add_argument(
542
- "--warmup", default=0, type=int, help="number of times to run the model to do warmup"
543
- )
540
+ if name == "validate":
541
+ parser.add_argument(
542
+ "--repeat",
543
+ default=1,
544
+ type=int,
545
+ help="number of times to run the model to measures inference time",
546
+ )
547
+ parser.add_argument(
548
+ "--warmup",
549
+ default=0,
550
+ type=int,
551
+ help="number of times to run the model to do warmup",
552
+ )
544
553
  parser.add_argument(
545
554
  "--outnames",
546
555
  help="This comma separated list defines the output names "
547
556
  "the onnx exporter should use.",
548
557
  default="",
549
558
  )
550
- parser.add_argument(
551
- "--ort-logs",
552
- default=False,
553
- action=BooleanOptionalAction,
554
- help="Enables onnxruntime logging when the session is created",
555
- )
559
+ if name == "validate":
560
+ parser.add_argument(
561
+ "--ort-logs",
562
+ default=False,
563
+ action=BooleanOptionalAction,
564
+ help="Enables onnxruntime logging when the session is created",
565
+ )
566
+ parser.add_argument(
567
+ "--quiet-input-sets",
568
+ default="",
569
+ help="Avoids raising an exception when an input sets does not work with "
570
+ "the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
571
+ )
556
572
  return parser
557
573
 
558
574
 
@@ -614,6 +630,7 @@ def _cmd_validate(argv: List[Any]):
614
630
  warmup=args.warmup,
615
631
  inputs2=args.inputs2,
616
632
  ort_logs=args.ort_logs,
633
+ quiet_input_sets=set(args.quiet_input_sets.split(",")),
617
634
  output_names=(
618
635
  None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
619
636
  ),
@@ -624,6 +641,94 @@ def _cmd_validate(argv: List[Any]):
624
641
  print(f":{k},{v};")
625
642
 
626
643
 
644
+ def _cmd_export_sample(argv: List[Any]):
645
+ from .helpers import string_type
646
+ from .torch_models.validate import get_inputs_for_task, _make_folder_name
647
+ from .torch_models.code_sample import code_sample
648
+ from .tasks import supported_tasks
649
+
650
+ parser = get_parser_validate("exportsample")
651
+ args = parser.parse_args(argv[1:])
652
+ if not args.task and not args.mid:
653
+ print("-- list of supported tasks:")
654
+ print("\n".join(supported_tasks()))
655
+ elif not args.mid:
656
+ data = get_inputs_for_task(args.task)
657
+ if args.verbose:
658
+ print(f"task: {args.task}")
659
+ max_length = max(len(k) for k in data["inputs"]) + 1
660
+ print("-- inputs")
661
+ for k, v in data["inputs"].items():
662
+ print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
663
+ print("-- dynamic_shapes")
664
+ for k, v in data["dynamic_shapes"].items():
665
+ print(f" + {k.ljust(max_length)}: {string_type(v)}")
666
+ else:
667
+ # Let's skip any invalid combination if known to be unsupported
668
+ if (
669
+ "onnx" not in (args.export or "")
670
+ and "custom" not in (args.export or "")
671
+ and (args.opt or "")
672
+ ):
673
+ print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
674
+ return
675
+ patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
676
+ code = code_sample(
677
+ model_id=args.mid,
678
+ task=args.task,
679
+ do_run=args.run,
680
+ verbose=args.verbose,
681
+ quiet=args.quiet,
682
+ same_as_pretrained=args.same_as_trained,
683
+ use_pretrained=args.trained,
684
+ dtype=args.dtype,
685
+ device=args.device,
686
+ patch=patch_dict,
687
+ rewrite=args.rewrite and patch_dict.get("patch", True),
688
+ stop_if_static=args.stop_if_static,
689
+ optimization=args.opt,
690
+ exporter=args.export,
691
+ dump_folder=args.dump_folder,
692
+ drop_inputs=None if not args.drop else args.drop.split(","),
693
+ input_options=args.iop,
694
+ model_options=args.mop,
695
+ subfolder=args.subfolder,
696
+ opset=args.opset,
697
+ runtime=args.runtime,
698
+ output_names=(
699
+ None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
700
+ ),
701
+ )
702
+ if args.dump_folder:
703
+ os.makedirs(args.dump_folder, exist_ok=True)
704
+ name = (
705
+ _make_folder_name(
706
+ model_id=args.mid,
707
+ exporter=args.export,
708
+ optimization=args.opt,
709
+ dtype=args.dtype,
710
+ device=args.device,
711
+ subfolder=args.subfolder,
712
+ opset=args.opset,
713
+ drop_inputs=None if not args.drop else args.drop.split(","),
714
+ same_as_pretrained=args.same_as_trained,
715
+ use_pretrained=args.trained,
716
+ task=args.task,
717
+ ).replace("/", "-")
718
+ + ".py"
719
+ )
720
+ fullname = os.path.join(args.dump_folder, name)
721
+ if args.verbose:
722
+ print(f"-- prints code in {fullname!r}")
723
+ print("--")
724
+ with open(fullname, "w") as f:
725
+ f.write(code)
726
+ if args.verbose:
727
+ print("-- done")
728
+ else:
729
+ print(code)
730
+
731
+
627
732
  def get_parser_stats() -> ArgumentParser:
628
733
  parser = ArgumentParser(
629
734
  prog="stats",
@@ -834,7 +939,7 @@ def get_parser_agg() -> ArgumentParser:
834
939
  "n_model_pass,n_model_faster,"
835
940
  "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
836
941
  "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
837
- "n_node_layer_normalization,n_node_layer_normalization23,"
942
+ "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
838
943
  "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
839
944
  "n_node_constant,n_node_shape,n_node_expand,"
840
945
  "n_node_function,n_node_initializer,n_node_scatter,"
@@ -953,14 +1058,15 @@ def get_main_parser() -> ArgumentParser:
953
1058
  Type 'python -m onnx_diagnostic <cmd> --help'
954
1059
  to get help for a specific command.
955
1060
 
956
- agg - aggregates statistics from multiple files
957
- config - prints a configuration for a model id
958
- find - find node consuming or producing a result
959
- lighten - makes an onnx model lighter by removing the weights,
960
- print - prints the model on standard output
961
- stats - produces statistics on a model
962
- unlighten - restores an onnx model produces by the previous experiment
963
- validate - validate a model
1061
+ agg - aggregates statistics from multiple files
1062
+ config - prints a configuration for a model id
1063
+ exportsample - produces a code to export a model
1064
+ find - find node consuming or producing a result
1065
+ lighten - makes an onnx model lighter by removing the weights,
1066
+ print - prints the model on standard output
1067
+ stats - produces statistics on a model
1068
+ unlighten - restores an onnx model produces by the previous experiment
1069
+ validate - validate a model
964
1070
  """
965
1071
  ),
966
1072
  )
@@ -969,6 +1075,7 @@ def get_main_parser() -> ArgumentParser:
969
1075
  choices=[
970
1076
  "agg",
971
1077
  "config",
1078
+ "exportsample",
972
1079
  "find",
973
1080
  "lighten",
974
1081
  "print",
@@ -991,6 +1098,7 @@ def main(argv: Optional[List[Any]] = None):
991
1098
  validate=_cmd_validate,
992
1099
  stats=_cmd_stats,
993
1100
  agg=_cmd_agg,
1101
+ exportsample=_cmd_export_sample,
994
1102
  )
995
1103
 
996
1104
  if argv is None:
@@ -1013,13 +1121,14 @@ def main(argv: Optional[List[Any]] = None):
1013
1121
  validate=get_parser_validate,
1014
1122
  stats=get_parser_stats,
1015
1123
  agg=get_parser_agg,
1124
+ exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
1016
1125
  )
1017
1126
  cmd = argv[0]
1018
1127
  if cmd not in parsers:
1019
1128
  raise ValueError(
1020
1129
  f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
1021
1130
  )
1022
- parser = parsers[cmd]()
1131
+ parser = parsers[cmd]() # type: ignore[operator]
1023
1132
  parser.parse_args(argv[1:])
1024
1133
  raise RuntimeError("The programme should have exited before.")
1025
1134
 
@@ -8,17 +8,17 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
8
8
  DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
9
9
 
10
10
 
11
- def flatten_dynamic_shapes(ds: Any) -> Any:
11
+ def _flatten_dynamic_shapes(ds: Any) -> Any:
12
12
  """Flattens the dynamic shapes."""
13
13
  if isinstance(ds, list):
14
- return _flat_list([flatten_dynamic_shapes(t) for t in ds])
14
+ return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
15
15
  if isinstance(ds, tuple):
16
- return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
16
+ return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
17
17
  if isinstance(ds, dict):
18
18
  if all(isinstance(i, int) for i in ds):
19
19
  # That's a dynamic shape
20
20
  return ds
21
- return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
21
+ return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
22
22
  raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
23
23
 
24
24
 
@@ -226,7 +226,7 @@ class CoupleInputsDynamicShapes:
226
226
  for i, d in enumerate(inputs.shape):
227
227
  if i in ds and not isinstance(ds[i], int):
228
228
  # dynamic then
229
- if d in {0, 1}:
229
+ if isinstance(d, int) and d in {0, 1}:
230
230
  # export issues for sure
231
231
  issues[i] = f"d=[{d}]"
232
232
  return issues if issues else None
@@ -380,7 +380,7 @@ class CoupleInputsDynamicShapes:
380
380
  flat, spec = torch.utils._pytree.tree_flatten(inputs)
381
381
  if all(isinstance(t, torch.Tensor) for t in flat):
382
382
  # We need to flatten dynamic shapes as well
383
- ds = flatten_dynamic_shapes(ds)
383
+ ds = _flatten_dynamic_shapes(ds)
384
384
  res = cls._generic_walker_step(
385
385
  processor, flat, ds, flatten_unflatten=flatten_unflatten
386
386
  )
@@ -1,9 +1,10 @@
1
- from typing import Any, Dict, List, Set, Tuple, Union
1
+ from typing import Any, Dict, List, Set, Optional, Tuple, Union
2
2
  from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
3
+ from ..helpers.fake_tensor_helper import fake_reshape
3
4
  from .dynamic_shapes import ModelInputs
4
5
 
5
6
 
6
- def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
7
+ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
7
8
  """
8
9
  Returns the dynamic shapes for the given inputs.
9
10
  All dimensions are considered as dynamic.
@@ -18,7 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
18
19
  import pprint
19
20
  import torch
20
21
  from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
21
- from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
22
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
22
23
  from onnx_diagnostic.torch_export_patches import torch_export_patches
23
24
 
24
25
  bsize, nheads, slen, dim = 2, 1, 30, 96
@@ -32,7 +33,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
32
33
  ),
33
34
  )
34
35
  with torch_export_patches(patch_transformers=True):
35
- ds = all_dynamic_shape_from_inputs(inputs)
36
+ ds = all_dynamic_shapes_from_inputs(inputs)
36
37
  pprint.pprint(ds)
37
38
 
38
39
  For this function to work, patches must be enabled if :epkg:`transformers`
@@ -50,7 +51,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
50
51
  make_sliding_window_cache,
51
52
  make_static_cache,
52
53
  )
53
- from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
54
+ from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
54
55
  from onnx_diagnostic.torch_export_patches import torch_export_patches
55
56
 
56
57
  caches = [
@@ -104,7 +105,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
104
105
  with torch_export_patches(patch_transformers=True):
105
106
  for cache in caches:
106
107
  print(f"-- {cache.__class__.__name__}")
107
- pprint.pprint(all_dynamic_shape_from_inputs(cache))
108
+ pprint.pprint(all_dynamic_shapes_from_inputs(cache))
108
109
  """
109
110
  if isinstance(dim_prefix, str):
110
111
  prefixes: Set[str] = set()
@@ -199,3 +200,120 @@ def guess_dynamic_shapes_from_inputs(
199
200
  """
200
201
  mi = ModelInputs(None, inputs)
201
202
  return mi.guess_dynamic_shapes(auto=auto)
203
+
204
+
205
+ def make_fake_with_dynamic_dimensions(
206
+ x: Any,
207
+ dynamic_shapes: Any,
208
+ fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
209
+ ) -> Tuple[Any, "FakeTensorMode"]: # noqa: F821
210
+ """
211
+ Replaces all tensors by fake tensor respecting the same
212
+ constraints as the following dynamic shapes.
213
+ This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
214
+
215
+ .. runpython::
216
+ :showcode:
217
+
218
+ import pprint
219
+ import torch
220
+ from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
221
+ from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
222
+
223
+ inputs, _ = make_fake_with_dynamic_dimensions(
224
+ dict(
225
+ input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
226
+ attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
227
+ position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
228
+ past_key_values=make_dynamic_cache(
229
+ [
230
+ (
231
+ torch.rand((2, 32, 30, 96), dtype=torch.float16),
232
+ torch.rand((2, 32, 30, 96), dtype=torch.float16),
233
+ ),
234
+ (
235
+ torch.rand((2, 32, 30, 96), dtype=torch.float16),
236
+ torch.rand((2, 32, 30, 96), dtype=torch.float16),
237
+ ),
238
+ ]
239
+ ),
240
+ ),
241
+ dynamic_shapes={
242
+ "input_ids": {0: "batch", 1: "seq_length"},
243
+ "attention_mask": {0: "batch", 1: "cache+seq"},
244
+ "position_ids": {0: "batch", 1: "seq_length"},
245
+ "past_key_values": [
246
+ [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
247
+ [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
248
+ ],
249
+ },
250
+ )
251
+ pprint.pprint(inputs)
252
+ """
253
+ if x is None:
254
+ return None, None
255
+ if fake_mode is None:
256
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv
257
+ from torch._subclasses.fake_tensor import FakeTensorMode
258
+
259
+ shape_env = ShapeEnv()
260
+ fake_mode = FakeTensorMode(shape_env=shape_env)
261
+
262
+ if isinstance(x, (list, tuple)):
263
+ return (
264
+ x.__class__(
265
+ [
266
+ make_fake_with_dynamic_dimensions(
267
+ i, fake_mode=fake_mode, dynamic_shapes=ds
268
+ )[0]
269
+ for i, ds in zip(x, dynamic_shapes)
270
+ ]
271
+ ),
272
+ fake_mode,
273
+ )
274
+ if isinstance(x, dict):
275
+ return {
276
+ k: make_fake_with_dynamic_dimensions(
277
+ v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
278
+ )[0]
279
+ for k, v in x.items()
280
+ }, fake_mode
281
+
282
+ if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
283
+ assert hasattr(x, "layers"), (
284
+ f"Une more recent version of transformers (>=4.55), "
285
+ f"'layers' not found in class {type(x)}"
286
+ )
287
+ assert (
288
+ isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
289
+ ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
290
+ for il, layer in enumerate(x.layers):
291
+ assert hasattr(layer, "keys") and hasattr(layer, "values"), (
292
+ f"Une more recent version of transformers (>=4.55), 'layers' "
293
+ f"not found in class {type(layer)} ({dir(layer)})"
294
+ )
295
+ layer.keys = make_fake_with_dynamic_dimensions(
296
+ layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
297
+ )[0]
298
+ layer.values = make_fake_with_dynamic_dimensions(
299
+ layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
300
+ )[0]
301
+ return x, fake_mode
302
+ if x.__class__.__name__ == "EncoderDecoderCache":
303
+ make_fake_with_dynamic_dimensions(
304
+ x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
305
+ )
306
+ make_fake_with_dynamic_dimensions(
307
+ x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
308
+ )
309
+ return x, fake_mode
310
+ if hasattr(x, "shape"):
311
+ t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
312
+ assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
313
+ assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
314
+ return t, fake_mode
315
+ from ..helpers import string_type
316
+
317
+ raise TypeError(
318
+ f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
319
+ )
@@ -979,7 +979,11 @@ class ExtTestCase(unittest.TestCase):
979
979
  else:
980
980
  for e, g in zip(expected, value):
981
981
  self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
982
- elif expected.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"):
982
+ elif expected.__class__.__name__ in (
983
+ "DynamicCache",
984
+ "SlidingWindowCache",
985
+ "HybridCache",
986
+ ):
983
987
  self.assertEqual(type(expected), type(value), msg=msg)
984
988
  atts = ["key_cache", "value_cache"]
985
989
  self.assertEqualAny(