onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +156 -47
- onnx_diagnostic/export/dynamic_shapes.py +6 -6
- onnx_diagnostic/export/shape_helper.py +124 -6
- onnx_diagnostic/ext_test_case.py +5 -1
- onnx_diagnostic/helpers/cache_helper.py +68 -42
- onnx_diagnostic/helpers/config_helper.py +2 -1
- onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
- onnx_diagnostic/helpers/helper.py +3 -0
- onnx_diagnostic/helpers/rt_helper.py +3 -3
- onnx_diagnostic/tasks/image_text_to_text.py +7 -6
- onnx_diagnostic/tasks/text_generation.py +7 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
- onnx_diagnostic/torch_models/validate.py +73 -29
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -371,30 +371,34 @@ class _BoolOrParseDictPatch(argparse.Action):
|
|
|
371
371
|
setattr(namespace, self.dest, d)
|
|
372
372
|
|
|
373
373
|
|
|
374
|
-
def get_parser_validate() -> ArgumentParser:
|
|
374
|
+
def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
375
375
|
parser = ArgumentParser(
|
|
376
|
-
prog=
|
|
376
|
+
prog=name,
|
|
377
377
|
description=textwrap.dedent(
|
|
378
378
|
"""
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
379
|
+
Validates a model for a particular task given the model id.
|
|
380
|
+
It exports the model and then validates it by computing the discrepancies
|
|
381
|
+
on different input sets.
|
|
382
|
+
"""
|
|
383
|
+
if name == "validate"
|
|
384
|
+
else """
|
|
385
|
+
Creates a script to export a model for a particular task given the model id.
|
|
382
386
|
"""
|
|
383
387
|
),
|
|
384
388
|
epilog=textwrap.dedent(
|
|
385
|
-
"""
|
|
389
|
+
f"""
|
|
386
390
|
If the model id is specified, one untrained version of it is instantiated.
|
|
387
391
|
Examples:
|
|
388
392
|
|
|
389
|
-
python -m onnx_diagnostic
|
|
393
|
+
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
|
|
390
394
|
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
391
395
|
--dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
|
|
392
396
|
|
|
393
|
-
python -m onnx_diagnostic
|
|
397
|
+
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
|
|
394
398
|
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
395
399
|
--dtype float16 --device cuda --patch --export custom --opt default
|
|
396
400
|
|
|
397
|
-
python -m onnx_diagnostic
|
|
401
|
+
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
|
|
398
402
|
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
399
403
|
--dtype float16 --device cuda --export modelbuilder
|
|
400
404
|
|
|
@@ -405,12 +409,12 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
405
409
|
The behaviour may be modified compare the original configuration,
|
|
406
410
|
the following argument can be rope_scaling to dynamic:
|
|
407
411
|
|
|
408
|
-
--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
|
|
412
|
+
--mop \"rope_scaling={{'rope_type': 'dynamic', 'factor': 10.0}}\""
|
|
409
413
|
|
|
410
414
|
You can profile the command line by running:
|
|
411
415
|
|
|
412
|
-
pyinstrument -m onnx_diagnostic
|
|
413
|
-
pyinstrument -r html -o profile.html -m onnx_diagnostic
|
|
416
|
+
pyinstrument -m onnx_diagnostic {name} ...
|
|
417
|
+
pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ...
|
|
414
418
|
"""
|
|
415
419
|
),
|
|
416
420
|
formatter_class=RawTextHelpFormatter,
|
|
@@ -460,19 +464,19 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
460
464
|
"--same-as-trained",
|
|
461
465
|
default=False,
|
|
462
466
|
action=BooleanOptionalAction,
|
|
463
|
-
help="Validates a model identical to the trained model but not trained.",
|
|
467
|
+
help="Validates or exports a model identical to the trained model but not trained.",
|
|
464
468
|
)
|
|
465
469
|
parser.add_argument(
|
|
466
470
|
"--trained",
|
|
467
471
|
default=False,
|
|
468
472
|
action=BooleanOptionalAction,
|
|
469
|
-
help="Validates the trained model (requires downloading).",
|
|
473
|
+
help="Validates or exports the trained model (requires downloading).",
|
|
470
474
|
)
|
|
471
475
|
parser.add_argument(
|
|
472
476
|
"--inputs2",
|
|
473
477
|
default=1,
|
|
474
478
|
type=int,
|
|
475
|
-
help="Validates the model on a second set of inputs\n"
|
|
479
|
+
help="Validates or exports the model on a second set of inputs\n"
|
|
476
480
|
"to check the exported model supports dynamism. The values is used "
|
|
477
481
|
"as an increment to the first set of inputs. A high value may trick "
|
|
478
482
|
"a different behavior in the model and missed by the exporter.",
|
|
@@ -504,13 +508,14 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
504
508
|
"--subfolder",
|
|
505
509
|
help="Subfolder where to find the model and the configuration.",
|
|
506
510
|
)
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
511
|
+
if name == "validate":
|
|
512
|
+
parser.add_argument(
|
|
513
|
+
"--ortfusiontype",
|
|
514
|
+
required=False,
|
|
515
|
+
help="Applies onnxruntime fusion, this parameter should contain the\n"
|
|
516
|
+
"model type or multiple values separated by `|`. `ALL` can be used\n"
|
|
517
|
+
"to run them all.",
|
|
518
|
+
)
|
|
514
519
|
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
|
|
515
520
|
parser.add_argument("--dtype", help="Changes dtype if necessary.")
|
|
516
521
|
parser.add_argument("--device", help="Changes the device if necessary.")
|
|
@@ -532,27 +537,38 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
532
537
|
"--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
|
|
533
538
|
action=_ParseDict,
|
|
534
539
|
)
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
540
|
+
if name == "validate":
|
|
541
|
+
parser.add_argument(
|
|
542
|
+
"--repeat",
|
|
543
|
+
default=1,
|
|
544
|
+
type=int,
|
|
545
|
+
help="number of times to run the model to measures inference time",
|
|
546
|
+
)
|
|
547
|
+
parser.add_argument(
|
|
548
|
+
"--warmup",
|
|
549
|
+
default=0,
|
|
550
|
+
type=int,
|
|
551
|
+
help="number of times to run the model to do warmup",
|
|
552
|
+
)
|
|
544
553
|
parser.add_argument(
|
|
545
554
|
"--outnames",
|
|
546
555
|
help="This comma separated list defines the output names "
|
|
547
556
|
"the onnx exporter should use.",
|
|
548
557
|
default="",
|
|
549
558
|
)
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
559
|
+
if name == "validate":
|
|
560
|
+
parser.add_argument(
|
|
561
|
+
"--ort-logs",
|
|
562
|
+
default=False,
|
|
563
|
+
action=BooleanOptionalAction,
|
|
564
|
+
help="Enables onnxruntime logging when the session is created",
|
|
565
|
+
)
|
|
566
|
+
parser.add_argument(
|
|
567
|
+
"--quiet-input-sets",
|
|
568
|
+
default="",
|
|
569
|
+
help="Avoids raising an exception when an input sets does not work with "
|
|
570
|
+
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
|
|
571
|
+
)
|
|
556
572
|
return parser
|
|
557
573
|
|
|
558
574
|
|
|
@@ -614,6 +630,7 @@ def _cmd_validate(argv: List[Any]):
|
|
|
614
630
|
warmup=args.warmup,
|
|
615
631
|
inputs2=args.inputs2,
|
|
616
632
|
ort_logs=args.ort_logs,
|
|
633
|
+
quiet_input_sets=set(args.quiet_input_sets.split(",")),
|
|
617
634
|
output_names=(
|
|
618
635
|
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
|
|
619
636
|
),
|
|
@@ -624,6 +641,94 @@ def _cmd_validate(argv: List[Any]):
|
|
|
624
641
|
print(f":{k},{v};")
|
|
625
642
|
|
|
626
643
|
|
|
644
|
+
def _cmd_export_sample(argv: List[Any]):
|
|
645
|
+
from .helpers import string_type
|
|
646
|
+
from .torch_models.validate import get_inputs_for_task, _make_folder_name
|
|
647
|
+
from .torch_models.code_sample import code_sample
|
|
648
|
+
from .tasks import supported_tasks
|
|
649
|
+
|
|
650
|
+
parser = get_parser_validate("exportsample")
|
|
651
|
+
args = parser.parse_args(argv[1:])
|
|
652
|
+
if not args.task and not args.mid:
|
|
653
|
+
print("-- list of supported tasks:")
|
|
654
|
+
print("\n".join(supported_tasks()))
|
|
655
|
+
elif not args.mid:
|
|
656
|
+
data = get_inputs_for_task(args.task)
|
|
657
|
+
if args.verbose:
|
|
658
|
+
print(f"task: {args.task}")
|
|
659
|
+
max_length = max(len(k) for k in data["inputs"]) + 1
|
|
660
|
+
print("-- inputs")
|
|
661
|
+
for k, v in data["inputs"].items():
|
|
662
|
+
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
|
|
663
|
+
print("-- dynamic_shapes")
|
|
664
|
+
for k, v in data["dynamic_shapes"].items():
|
|
665
|
+
print(f" + {k.ljust(max_length)}: {string_type(v)}")
|
|
666
|
+
else:
|
|
667
|
+
# Let's skip any invalid combination if known to be unsupported
|
|
668
|
+
if (
|
|
669
|
+
"onnx" not in (args.export or "")
|
|
670
|
+
and "custom" not in (args.export or "")
|
|
671
|
+
and (args.opt or "")
|
|
672
|
+
):
|
|
673
|
+
print(f"code-sample - unsupported args: export={args.export!r}, opt={args.opt!r}")
|
|
674
|
+
return
|
|
675
|
+
patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
|
|
676
|
+
code = code_sample(
|
|
677
|
+
model_id=args.mid,
|
|
678
|
+
task=args.task,
|
|
679
|
+
do_run=args.run,
|
|
680
|
+
verbose=args.verbose,
|
|
681
|
+
quiet=args.quiet,
|
|
682
|
+
same_as_pretrained=args.same_as_trained,
|
|
683
|
+
use_pretrained=args.trained,
|
|
684
|
+
dtype=args.dtype,
|
|
685
|
+
device=args.device,
|
|
686
|
+
patch=patch_dict,
|
|
687
|
+
rewrite=args.rewrite and patch_dict.get("patch", True),
|
|
688
|
+
stop_if_static=args.stop_if_static,
|
|
689
|
+
optimization=args.opt,
|
|
690
|
+
exporter=args.export,
|
|
691
|
+
dump_folder=args.dump_folder,
|
|
692
|
+
drop_inputs=None if not args.drop else args.drop.split(","),
|
|
693
|
+
input_options=args.iop,
|
|
694
|
+
model_options=args.mop,
|
|
695
|
+
subfolder=args.subfolder,
|
|
696
|
+
opset=args.opset,
|
|
697
|
+
runtime=args.runtime,
|
|
698
|
+
output_names=(
|
|
699
|
+
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
|
|
700
|
+
),
|
|
701
|
+
)
|
|
702
|
+
if args.dump_folder:
|
|
703
|
+
os.makedirs(args.dump_folder, exist_ok=True)
|
|
704
|
+
name = (
|
|
705
|
+
_make_folder_name(
|
|
706
|
+
model_id=args.mid,
|
|
707
|
+
exporter=args.export,
|
|
708
|
+
optimization=args.opt,
|
|
709
|
+
dtype=args.dtype,
|
|
710
|
+
device=args.device,
|
|
711
|
+
subfolder=args.subfolder,
|
|
712
|
+
opset=args.opset,
|
|
713
|
+
drop_inputs=None if not args.drop else args.drop.split(","),
|
|
714
|
+
same_as_pretrained=args.same_as_trained,
|
|
715
|
+
use_pretrained=args.trained,
|
|
716
|
+
task=args.task,
|
|
717
|
+
).replace("/", "-")
|
|
718
|
+
+ ".py"
|
|
719
|
+
)
|
|
720
|
+
fullname = os.path.join(args.dump_folder, name)
|
|
721
|
+
if args.verbose:
|
|
722
|
+
print(f"-- prints code in {fullname!r}")
|
|
723
|
+
print("--")
|
|
724
|
+
with open(fullname, "w") as f:
|
|
725
|
+
f.write(code)
|
|
726
|
+
if args.verbose:
|
|
727
|
+
print("-- done")
|
|
728
|
+
else:
|
|
729
|
+
print(code)
|
|
730
|
+
|
|
731
|
+
|
|
627
732
|
def get_parser_stats() -> ArgumentParser:
|
|
628
733
|
parser = ArgumentParser(
|
|
629
734
|
prog="stats",
|
|
@@ -834,7 +939,7 @@ def get_parser_agg() -> ArgumentParser:
|
|
|
834
939
|
"n_model_pass,n_model_faster,"
|
|
835
940
|
"n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
|
|
836
941
|
"n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
|
|
837
|
-
"n_node_layer_normalization,n_node_layer_normalization23,"
|
|
942
|
+
"n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
|
|
838
943
|
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
|
|
839
944
|
"n_node_constant,n_node_shape,n_node_expand,"
|
|
840
945
|
"n_node_function,n_node_initializer,n_node_scatter,"
|
|
@@ -953,14 +1058,15 @@ def get_main_parser() -> ArgumentParser:
|
|
|
953
1058
|
Type 'python -m onnx_diagnostic <cmd> --help'
|
|
954
1059
|
to get help for a specific command.
|
|
955
1060
|
|
|
956
|
-
agg
|
|
957
|
-
config
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
1061
|
+
agg - aggregates statistics from multiple files
|
|
1062
|
+
config - prints a configuration for a model id
|
|
1063
|
+
exportsample - produces a code to export a model
|
|
1064
|
+
find - find node consuming or producing a result
|
|
1065
|
+
lighten - makes an onnx model lighter by removing the weights,
|
|
1066
|
+
print - prints the model on standard output
|
|
1067
|
+
stats - produces statistics on a model
|
|
1068
|
+
unlighten - restores an onnx model produces by the previous experiment
|
|
1069
|
+
validate - validate a model
|
|
964
1070
|
"""
|
|
965
1071
|
),
|
|
966
1072
|
)
|
|
@@ -969,6 +1075,7 @@ def get_main_parser() -> ArgumentParser:
|
|
|
969
1075
|
choices=[
|
|
970
1076
|
"agg",
|
|
971
1077
|
"config",
|
|
1078
|
+
"exportsample",
|
|
972
1079
|
"find",
|
|
973
1080
|
"lighten",
|
|
974
1081
|
"print",
|
|
@@ -991,6 +1098,7 @@ def main(argv: Optional[List[Any]] = None):
|
|
|
991
1098
|
validate=_cmd_validate,
|
|
992
1099
|
stats=_cmd_stats,
|
|
993
1100
|
agg=_cmd_agg,
|
|
1101
|
+
exportsample=_cmd_export_sample,
|
|
994
1102
|
)
|
|
995
1103
|
|
|
996
1104
|
if argv is None:
|
|
@@ -1013,13 +1121,14 @@ def main(argv: Optional[List[Any]] = None):
|
|
|
1013
1121
|
validate=get_parser_validate,
|
|
1014
1122
|
stats=get_parser_stats,
|
|
1015
1123
|
agg=get_parser_agg,
|
|
1124
|
+
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
|
|
1016
1125
|
)
|
|
1017
1126
|
cmd = argv[0]
|
|
1018
1127
|
if cmd not in parsers:
|
|
1019
1128
|
raise ValueError(
|
|
1020
1129
|
f"Unknown command {cmd!r}, it should be in {list(sorted(parsers))}."
|
|
1021
1130
|
)
|
|
1022
|
-
parser = parsers[cmd]()
|
|
1131
|
+
parser = parsers[cmd]() # type: ignore[operator]
|
|
1023
1132
|
parser.parse_args(argv[1:])
|
|
1024
1133
|
raise RuntimeError("The programme should have exited before.")
|
|
1025
1134
|
|
|
@@ -8,17 +8,17 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
|
8
8
|
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
|
|
9
9
|
|
|
10
10
|
|
|
11
|
-
def
|
|
11
|
+
def _flatten_dynamic_shapes(ds: Any) -> Any:
|
|
12
12
|
"""Flattens the dynamic shapes."""
|
|
13
13
|
if isinstance(ds, list):
|
|
14
|
-
return _flat_list([
|
|
14
|
+
return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
|
|
15
15
|
if isinstance(ds, tuple):
|
|
16
|
-
return tuple(_flat_list([
|
|
16
|
+
return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
|
|
17
17
|
if isinstance(ds, dict):
|
|
18
18
|
if all(isinstance(i, int) for i in ds):
|
|
19
19
|
# That's a dynamic shape
|
|
20
20
|
return ds
|
|
21
|
-
return _flat_list([
|
|
21
|
+
return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
|
|
22
22
|
raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
|
|
23
23
|
|
|
24
24
|
|
|
@@ -226,7 +226,7 @@ class CoupleInputsDynamicShapes:
|
|
|
226
226
|
for i, d in enumerate(inputs.shape):
|
|
227
227
|
if i in ds and not isinstance(ds[i], int):
|
|
228
228
|
# dynamic then
|
|
229
|
-
if d in {0, 1}:
|
|
229
|
+
if isinstance(d, int) and d in {0, 1}:
|
|
230
230
|
# export issues for sure
|
|
231
231
|
issues[i] = f"d=[{d}]"
|
|
232
232
|
return issues if issues else None
|
|
@@ -380,7 +380,7 @@ class CoupleInputsDynamicShapes:
|
|
|
380
380
|
flat, spec = torch.utils._pytree.tree_flatten(inputs)
|
|
381
381
|
if all(isinstance(t, torch.Tensor) for t in flat):
|
|
382
382
|
# We need to flatten dynamic shapes as well
|
|
383
|
-
ds =
|
|
383
|
+
ds = _flatten_dynamic_shapes(ds)
|
|
384
384
|
res = cls._generic_walker_step(
|
|
385
385
|
processor, flat, ds, flatten_unflatten=flatten_unflatten
|
|
386
386
|
)
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
from typing import Any, Dict, List, Set, Tuple, Union
|
|
1
|
+
from typing import Any, Dict, List, Set, Optional, Tuple, Union
|
|
2
2
|
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
3
|
+
from ..helpers.fake_tensor_helper import fake_reshape
|
|
3
4
|
from .dynamic_shapes import ModelInputs
|
|
4
5
|
|
|
5
6
|
|
|
6
|
-
def
|
|
7
|
+
def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
|
|
7
8
|
"""
|
|
8
9
|
Returns the dynamic shapes for the given inputs.
|
|
9
10
|
All dimensions are considered as dynamic.
|
|
@@ -18,7 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
|
|
|
18
19
|
import pprint
|
|
19
20
|
import torch
|
|
20
21
|
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
|
|
21
|
-
from onnx_diagnostic.export.shape_helper import
|
|
22
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
|
|
22
23
|
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
23
24
|
|
|
24
25
|
bsize, nheads, slen, dim = 2, 1, 30, 96
|
|
@@ -32,7 +33,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
|
|
|
32
33
|
),
|
|
33
34
|
)
|
|
34
35
|
with torch_export_patches(patch_transformers=True):
|
|
35
|
-
ds =
|
|
36
|
+
ds = all_dynamic_shapes_from_inputs(inputs)
|
|
36
37
|
pprint.pprint(ds)
|
|
37
38
|
|
|
38
39
|
For this function to work, patches must be enabled if :epkg:`transformers`
|
|
@@ -50,7 +51,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
|
|
|
50
51
|
make_sliding_window_cache,
|
|
51
52
|
make_static_cache,
|
|
52
53
|
)
|
|
53
|
-
from onnx_diagnostic.export.shape_helper import
|
|
54
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
|
|
54
55
|
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
55
56
|
|
|
56
57
|
caches = [
|
|
@@ -104,7 +105,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
|
|
|
104
105
|
with torch_export_patches(patch_transformers=True):
|
|
105
106
|
for cache in caches:
|
|
106
107
|
print(f"-- {cache.__class__.__name__}")
|
|
107
|
-
pprint.pprint(
|
|
108
|
+
pprint.pprint(all_dynamic_shapes_from_inputs(cache))
|
|
108
109
|
"""
|
|
109
110
|
if isinstance(dim_prefix, str):
|
|
110
111
|
prefixes: Set[str] = set()
|
|
@@ -199,3 +200,120 @@ def guess_dynamic_shapes_from_inputs(
|
|
|
199
200
|
"""
|
|
200
201
|
mi = ModelInputs(None, inputs)
|
|
201
202
|
return mi.guess_dynamic_shapes(auto=auto)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def make_fake_with_dynamic_dimensions(
|
|
206
|
+
x: Any,
|
|
207
|
+
dynamic_shapes: Any,
|
|
208
|
+
fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
|
|
209
|
+
) -> Tuple[Any, "FakeTensorMode"]: # noqa: F821
|
|
210
|
+
"""
|
|
211
|
+
Replaces all tensors by fake tensor respecting the same
|
|
212
|
+
constraints as the following dynamic shapes.
|
|
213
|
+
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
|
|
214
|
+
|
|
215
|
+
.. runpython::
|
|
216
|
+
:showcode:
|
|
217
|
+
|
|
218
|
+
import pprint
|
|
219
|
+
import torch
|
|
220
|
+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
|
|
221
|
+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
|
|
222
|
+
|
|
223
|
+
inputs, _ = make_fake_with_dynamic_dimensions(
|
|
224
|
+
dict(
|
|
225
|
+
input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
|
|
226
|
+
attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
|
|
227
|
+
position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
|
|
228
|
+
past_key_values=make_dynamic_cache(
|
|
229
|
+
[
|
|
230
|
+
(
|
|
231
|
+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
|
|
232
|
+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
|
|
233
|
+
),
|
|
234
|
+
(
|
|
235
|
+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
|
|
236
|
+
torch.rand((2, 32, 30, 96), dtype=torch.float16),
|
|
237
|
+
),
|
|
238
|
+
]
|
|
239
|
+
),
|
|
240
|
+
),
|
|
241
|
+
dynamic_shapes={
|
|
242
|
+
"input_ids": {0: "batch", 1: "seq_length"},
|
|
243
|
+
"attention_mask": {0: "batch", 1: "cache+seq"},
|
|
244
|
+
"position_ids": {0: "batch", 1: "seq_length"},
|
|
245
|
+
"past_key_values": [
|
|
246
|
+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
|
|
247
|
+
[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
|
|
248
|
+
],
|
|
249
|
+
},
|
|
250
|
+
)
|
|
251
|
+
pprint.pprint(inputs)
|
|
252
|
+
"""
|
|
253
|
+
if x is None:
|
|
254
|
+
return None, None
|
|
255
|
+
if fake_mode is None:
|
|
256
|
+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
257
|
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
258
|
+
|
|
259
|
+
shape_env = ShapeEnv()
|
|
260
|
+
fake_mode = FakeTensorMode(shape_env=shape_env)
|
|
261
|
+
|
|
262
|
+
if isinstance(x, (list, tuple)):
|
|
263
|
+
return (
|
|
264
|
+
x.__class__(
|
|
265
|
+
[
|
|
266
|
+
make_fake_with_dynamic_dimensions(
|
|
267
|
+
i, fake_mode=fake_mode, dynamic_shapes=ds
|
|
268
|
+
)[0]
|
|
269
|
+
for i, ds in zip(x, dynamic_shapes)
|
|
270
|
+
]
|
|
271
|
+
),
|
|
272
|
+
fake_mode,
|
|
273
|
+
)
|
|
274
|
+
if isinstance(x, dict):
|
|
275
|
+
return {
|
|
276
|
+
k: make_fake_with_dynamic_dimensions(
|
|
277
|
+
v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
|
|
278
|
+
)[0]
|
|
279
|
+
for k, v in x.items()
|
|
280
|
+
}, fake_mode
|
|
281
|
+
|
|
282
|
+
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
|
|
283
|
+
assert hasattr(x, "layers"), (
|
|
284
|
+
f"Une more recent version of transformers (>=4.55), "
|
|
285
|
+
f"'layers' not found in class {type(x)}"
|
|
286
|
+
)
|
|
287
|
+
assert (
|
|
288
|
+
isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
|
|
289
|
+
), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
|
|
290
|
+
for il, layer in enumerate(x.layers):
|
|
291
|
+
assert hasattr(layer, "keys") and hasattr(layer, "values"), (
|
|
292
|
+
f"Une more recent version of transformers (>=4.55), 'layers' "
|
|
293
|
+
f"not found in class {type(layer)} ({dir(layer)})"
|
|
294
|
+
)
|
|
295
|
+
layer.keys = make_fake_with_dynamic_dimensions(
|
|
296
|
+
layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
|
|
297
|
+
)[0]
|
|
298
|
+
layer.values = make_fake_with_dynamic_dimensions(
|
|
299
|
+
layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
|
|
300
|
+
)[0]
|
|
301
|
+
return x, fake_mode
|
|
302
|
+
if x.__class__.__name__ == "EncoderDecoderCache":
|
|
303
|
+
make_fake_with_dynamic_dimensions(
|
|
304
|
+
x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
|
|
305
|
+
)
|
|
306
|
+
make_fake_with_dynamic_dimensions(
|
|
307
|
+
x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
|
|
308
|
+
)
|
|
309
|
+
return x, fake_mode
|
|
310
|
+
if hasattr(x, "shape"):
|
|
311
|
+
t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
|
|
312
|
+
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
|
|
313
|
+
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
|
|
314
|
+
return t, fake_mode
|
|
315
|
+
from ..helpers import string_type
|
|
316
|
+
|
|
317
|
+
raise TypeError(
|
|
318
|
+
f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
|
|
319
|
+
)
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -979,7 +979,11 @@ class ExtTestCase(unittest.TestCase):
|
|
|
979
979
|
else:
|
|
980
980
|
for e, g in zip(expected, value):
|
|
981
981
|
self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol)
|
|
982
|
-
elif expected.__class__.__name__ in (
|
|
982
|
+
elif expected.__class__.__name__ in (
|
|
983
|
+
"DynamicCache",
|
|
984
|
+
"SlidingWindowCache",
|
|
985
|
+
"HybridCache",
|
|
986
|
+
):
|
|
983
987
|
self.assertEqual(type(expected), type(value), msg=msg)
|
|
984
988
|
atts = ["key_cache", "value_cache"]
|
|
985
989
|
self.assertEqualAny(
|