onnx-diagnostic 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +213 -5
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +42 -20
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +1561 -177
- onnx_diagnostic/helpers/torch_helper.py +6 -2
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/image_text_to_text.py +69 -18
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +73 -5
- onnx_diagnostic/torch_models/hghub/hub_data.py +7 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +74 -14
- onnx_diagnostic/torch_models/validate.py +45 -16
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +29 -24
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -333,7 +333,24 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
333
333
|
of supported tasks.
|
|
334
334
|
"""
|
|
335
335
|
),
|
|
336
|
-
epilog=
|
|
336
|
+
epilog=textwrap.dedent(
|
|
337
|
+
"""
|
|
338
|
+
If the model id is specified, one untrained version of it is instantiated.
|
|
339
|
+
Examples:
|
|
340
|
+
|
|
341
|
+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
|
|
342
|
+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
343
|
+
--dtype float16 --device cuda --patch --export onnx-dynamo --opt ir
|
|
344
|
+
|
|
345
|
+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
|
|
346
|
+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
347
|
+
--dtype float16 --device cuda --patch --export custom --opt default
|
|
348
|
+
|
|
349
|
+
python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
|
|
350
|
+
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
|
|
351
|
+
--dtype float16 --device cuda --export modelbuilder
|
|
352
|
+
"""
|
|
353
|
+
),
|
|
337
354
|
formatter_class=RawTextHelpFormatter,
|
|
338
355
|
)
|
|
339
356
|
parser.add_argument("-m", "--mid", type=str, help="model id, usually <author>/<name>")
|
|
@@ -372,6 +389,12 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
372
389
|
type=int,
|
|
373
390
|
help="Raises an exception if a dynamic dimension becomes static.",
|
|
374
391
|
)
|
|
392
|
+
parser.add_argument(
|
|
393
|
+
"--same-as-trained",
|
|
394
|
+
default=False,
|
|
395
|
+
action=BooleanOptionalAction,
|
|
396
|
+
help="Validates a model identical to the trained model but not trained.",
|
|
397
|
+
)
|
|
375
398
|
parser.add_argument(
|
|
376
399
|
"--trained",
|
|
377
400
|
default=False,
|
|
@@ -487,7 +510,8 @@ def _cmd_validate(argv: List[Any]):
|
|
|
487
510
|
do_run=args.run,
|
|
488
511
|
verbose=args.verbose,
|
|
489
512
|
quiet=args.quiet,
|
|
490
|
-
|
|
513
|
+
same_as_pretrained=args.same_as_trained,
|
|
514
|
+
use_pretrained=args.trained,
|
|
491
515
|
dtype=args.dtype,
|
|
492
516
|
device=args.device,
|
|
493
517
|
patch=args.patch,
|
|
@@ -609,6 +633,178 @@ def _cmd_stats(argv: List[Any]):
|
|
|
609
633
|
print("done.")
|
|
610
634
|
|
|
611
635
|
|
|
636
|
+
def get_parser_agg() -> ArgumentParser:
|
|
637
|
+
parser = ArgumentParser(
|
|
638
|
+
prog="agg",
|
|
639
|
+
description=textwrap.dedent(
|
|
640
|
+
"""
|
|
641
|
+
Aggregates statistics coming from benchmarks.
|
|
642
|
+
Every run is a row. Every row is indexed by some keys,
|
|
643
|
+
and produces values. Every row has a date.
|
|
644
|
+
"""
|
|
645
|
+
),
|
|
646
|
+
epilog=textwrap.dedent(
|
|
647
|
+
"""
|
|
648
|
+
examples:\n
|
|
649
|
+
|
|
650
|
+
python -m onnx_diagnostic agg test_agg.xlsx raw/*.zip -v 1
|
|
651
|
+
"""
|
|
652
|
+
),
|
|
653
|
+
formatter_class=RawTextHelpFormatter,
|
|
654
|
+
)
|
|
655
|
+
parser.add_argument("output", help="output excel file")
|
|
656
|
+
parser.add_argument(
|
|
657
|
+
"inputs",
|
|
658
|
+
nargs="+",
|
|
659
|
+
help="input csv or zip files, at least 1, it can be a name, or search path",
|
|
660
|
+
)
|
|
661
|
+
parser.add_argument(
|
|
662
|
+
"--filter", default="rawdata_.*.csv", help="filter for input files inside zip files"
|
|
663
|
+
)
|
|
664
|
+
parser.add_argument(
|
|
665
|
+
"--recent",
|
|
666
|
+
default=True,
|
|
667
|
+
action=BooleanOptionalAction,
|
|
668
|
+
help="Keeps only the most recent experiment for the same of keys.",
|
|
669
|
+
)
|
|
670
|
+
parser.add_argument(
|
|
671
|
+
"--keep-last-date",
|
|
672
|
+
default=False,
|
|
673
|
+
action=BooleanOptionalAction,
|
|
674
|
+
help="Rewrite all dates to the last one to simplifies the analysis, "
|
|
675
|
+
"this assume changing the date does not add ambiguity, if any, option "
|
|
676
|
+
"--recent should be added.",
|
|
677
|
+
)
|
|
678
|
+
parser.add_argument(
|
|
679
|
+
"--raw",
|
|
680
|
+
default=True,
|
|
681
|
+
action=BooleanOptionalAction,
|
|
682
|
+
help="Keeps the raw data in a sheet.",
|
|
683
|
+
)
|
|
684
|
+
parser.add_argument("-t", "--time", default="DATE", help="Date or time column")
|
|
685
|
+
parser.add_argument(
|
|
686
|
+
"-k",
|
|
687
|
+
"--keys",
|
|
688
|
+
default="^version_.*,^model_.*,device,opt_patterns,suite,memory_peak,"
|
|
689
|
+
"machine,exporter,dynamic,rtopt,dtype,device,architecture",
|
|
690
|
+
help="List of columns to consider as keys, "
|
|
691
|
+
"multiple values are separated by `,`\n"
|
|
692
|
+
"regular expressions are allowed",
|
|
693
|
+
)
|
|
694
|
+
parser.add_argument(
|
|
695
|
+
"--drop-keys",
|
|
696
|
+
default="",
|
|
697
|
+
help="Drops keys from the given list. Something it is faster "
|
|
698
|
+
"to remove one than to select all the remaining ones.",
|
|
699
|
+
)
|
|
700
|
+
parser.add_argument(
|
|
701
|
+
"-w",
|
|
702
|
+
"--values",
|
|
703
|
+
default="^time_.*,^disc.*,^ERR_.*,CMD,^ITER.*,^onnx_.*,^op_onnx_.*,^peak_gpu_.*",
|
|
704
|
+
help="List of columns to consider as values, "
|
|
705
|
+
"multiple values are separated by `,`\n"
|
|
706
|
+
"regular expressions are allowed",
|
|
707
|
+
)
|
|
708
|
+
parser.add_argument(
|
|
709
|
+
"-i", "--ignored", default="^version_.*", help="List of columns to ignore"
|
|
710
|
+
)
|
|
711
|
+
parser.add_argument(
|
|
712
|
+
"-f",
|
|
713
|
+
"--formula",
|
|
714
|
+
default="speedup,bucket[speedup],ERR1,n_models,n_model_eager,"
|
|
715
|
+
"n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
|
|
716
|
+
"n_model_pass,n_model_faster,"
|
|
717
|
+
"n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
|
|
718
|
+
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
|
|
719
|
+
"n_node_constant,n_node_shape,n_node_expand,"
|
|
720
|
+
"n_node_function,n_node_initializer,n_node_scatter,"
|
|
721
|
+
"time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
|
|
722
|
+
help="Columns to compute after the aggregation was done.",
|
|
723
|
+
)
|
|
724
|
+
parser.add_argument(
|
|
725
|
+
"--views",
|
|
726
|
+
default="agg-suite,agg-all,disc,speedup,time,time_export,err,cmd,"
|
|
727
|
+
"bucket-speedup,raw-short,counts,peak-gpu,onnx",
|
|
728
|
+
help="Views to add to the output files.",
|
|
729
|
+
)
|
|
730
|
+
parser.add_argument(
|
|
731
|
+
"--csv",
|
|
732
|
+
default="raw-short",
|
|
733
|
+
help="Views to dump as csv files.",
|
|
734
|
+
)
|
|
735
|
+
parser.add_argument("-v", "--verbose", type=int, default=0, help="verbosity")
|
|
736
|
+
parser.add_argument(
|
|
737
|
+
"--filter-in",
|
|
738
|
+
default="",
|
|
739
|
+
help="adds a filter to filter in data, syntax is\n"
|
|
740
|
+
'``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
|
|
741
|
+
)
|
|
742
|
+
parser.add_argument(
|
|
743
|
+
"--filter-out",
|
|
744
|
+
default="",
|
|
745
|
+
help="adds a filter to filter out data, syntax is\n"
|
|
746
|
+
'``"<column1>:<value1>;<value2>/<column2>:<value3>"`` ...',
|
|
747
|
+
)
|
|
748
|
+
return parser
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def _cmd_agg(argv: List[Any]):
    """
    Implements command ``agg``.

    The function parses ``argv[1:]`` with :func:`get_parser_agg`,
    enumerates the input files whose name matches ``--filter``,
    loads each of them as a dataframe, applies ``--filter-in`` and
    ``--filter-out``, builds a ``CubeLogsPerformance`` from the
    result and writes it with ``to_excel`` into the output file.

    :param argv: command line, the first element (the command name)
        is skipped
    :raises AssertionError: if no input file is found or if one input
        does not contain the time column
    """
    from .helpers.log_helper import (
        CubeLogsPerformance,
        open_dataframe,
        enumerate_csv_files,
        filter_data,
    )

    def _split(option: str) -> List[str]:
        # ``"".split(",")`` returns ``[""]``: empty entries are dropped
        # so that an empty option means an empty list.
        return [item for item in option.split(",") if item]

    parser = get_parser_agg()
    args = parser.parse_args(argv[1:])
    reg = re.compile(args.filter)

    # The local variable is not called ``csv`` to avoid any confusion
    # with the standard module of the same name.
    csv_files = list(
        enumerate_csv_files(
            args.inputs, verbose=args.verbose, filtering=lambda name: bool(reg.search(name))
        )
    )
    assert (
        csv_files
    ), f"No csv files in {args.inputs}, args.filter={args.filter!r}, csv={csv_files}"
    if args.verbose:
        # tqdm is only needed to display a progress bar.
        from tqdm import tqdm

        loop = tqdm(csv_files)
    else:
        loop = csv_files
    dfs = []
    for c in loop:
        df = open_dataframe(c)
        assert (
            args.time in df.columns
        ), f"Missing time column {args.time!r} in {c!r}\n{df.head()}\n{sorted(df.columns)}"
        dfs.append(filter_data(df, filter_in=args.filter_in, filter_out=args.filter_out))

    drop_keys = set(_split(args.drop_keys))
    cube = CubeLogsPerformance(
        dfs,
        time=args.time,
        keys=[a for a in _split(args.keys) if a not in drop_keys],
        values=_split(args.values),
        ignored=_split(args.ignored),
        recent=args.recent,
        # formulas, views and csv are filtered the same way as keys,
        # values and ignored: no empty name is forwarded.
        formulas={k: k for k in _split(args.formula)},
        keep_last_date=args.keep_last_date,
    )
    # The cube is one level less verbose than the command itself.
    cube.load(verbose=max(args.verbose - 1, 0))
    if args.verbose:
        print(f"Dumps final file into {args.output!r}")
    cube.to_excel(
        args.output,
        {k: k for k in _split(args.views)},
        verbose=args.verbose,
        csv=_split(args.csv),
        raw=args.raw,
    )
    if args.verbose:
        print(f"Wrote {args.output!r}")
|
|
806
|
+
|
|
807
|
+
|
|
612
808
|
def get_main_parser() -> ArgumentParser:
|
|
613
809
|
parser = ArgumentParser(
|
|
614
810
|
prog="onnx_diagnostic",
|
|
@@ -619,19 +815,29 @@ def get_main_parser() -> ArgumentParser:
|
|
|
619
815
|
Type 'python -m onnx_diagnostic <cmd> --help'
|
|
620
816
|
to get help for a specific command.
|
|
621
817
|
|
|
818
|
+
agg - aggregates statistics from multiple files
|
|
622
819
|
config - prints a configuration for a model id
|
|
623
820
|
find - find node consuming or producing a result
|
|
624
821
|
lighten - makes an onnx model lighter by removing the weights,
|
|
625
|
-
unlighten - restores an onnx model produces by the previous experiment
|
|
626
822
|
print - prints the model on standard output
|
|
627
|
-
validate - validate a model
|
|
628
823
|
stats - produces statistics on a model
|
|
824
|
+
unlighten - restores an onnx model produces by the previous experiment
|
|
825
|
+
validate - validate a model
|
|
629
826
|
"""
|
|
630
827
|
),
|
|
631
828
|
)
|
|
632
829
|
parser.add_argument(
|
|
633
830
|
"cmd",
|
|
634
|
-
choices=[
|
|
831
|
+
choices=[
|
|
832
|
+
"agg",
|
|
833
|
+
"config",
|
|
834
|
+
"find",
|
|
835
|
+
"lighten",
|
|
836
|
+
"print",
|
|
837
|
+
"stats",
|
|
838
|
+
"unlighten",
|
|
839
|
+
"validate",
|
|
840
|
+
],
|
|
635
841
|
help="Selects a command.",
|
|
636
842
|
)
|
|
637
843
|
return parser
|
|
@@ -646,6 +852,7 @@ def main(argv: Optional[List[Any]] = None):
|
|
|
646
852
|
config=_cmd_config,
|
|
647
853
|
validate=_cmd_validate,
|
|
648
854
|
stats=_cmd_stats,
|
|
855
|
+
agg=_cmd_agg,
|
|
649
856
|
)
|
|
650
857
|
|
|
651
858
|
if argv is None:
|
|
@@ -667,6 +874,7 @@ def main(argv: Optional[List[Any]] = None):
|
|
|
667
874
|
config=get_parser_config,
|
|
668
875
|
validate=get_parser_validate,
|
|
669
876
|
stats=get_parser_stats,
|
|
877
|
+
agg=get_parser_agg,
|
|
670
878
|
)
|
|
671
879
|
cmd = argv[0]
|
|
672
880
|
if cmd not in parsers:
|
|
@@ -630,9 +630,12 @@ class ModelInputs:
|
|
|
630
630
|
method_name: str = "forward",
|
|
631
631
|
name: str = "main",
|
|
632
632
|
):
|
|
633
|
-
assert
|
|
634
|
-
model
|
|
635
|
-
),
|
|
633
|
+
assert (
|
|
634
|
+
model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
|
|
635
|
+
), (
|
|
636
|
+
f"unexpected type for model={type(model)}, "
|
|
637
|
+
f"it must be a torch.nn.Module or None"
|
|
638
|
+
)
|
|
636
639
|
assert name, (
|
|
637
640
|
f"name={name!r} cannot be empty this string is used to "
|
|
638
641
|
f"display meaningful error messages"
|
|
@@ -641,26 +644,42 @@ class ModelInputs:
|
|
|
641
644
|
self.model = model
|
|
642
645
|
self.level = level
|
|
643
646
|
self.method_name = method_name
|
|
644
|
-
self.forward = getattr(model, method_name)
|
|
645
|
-
self.signature = inspect.signature(self.forward)
|
|
647
|
+
self.forward = getattr(model, method_name) if model is not None else None
|
|
648
|
+
self.signature = inspect.signature(self.forward) if self.forward else None
|
|
646
649
|
|
|
647
650
|
# information about the signature
|
|
648
|
-
self.forward_parameter_names =
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
651
|
+
self.forward_parameter_names = (
|
|
652
|
+
set(
|
|
653
|
+
p.name
|
|
654
|
+
for p in self.signature.parameters.values()
|
|
655
|
+
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
|
|
656
|
+
)
|
|
657
|
+
if self.signature
|
|
658
|
+
else None
|
|
659
|
+
)
|
|
660
|
+
self.forward_ordered_parameter_names = (
|
|
661
|
+
list(self.signature.parameters) if self.signature else None
|
|
662
|
+
)
|
|
663
|
+
self.forward_positioned_parameter_names = (
|
|
664
|
+
[
|
|
665
|
+
p.name
|
|
666
|
+
for p in self.signature.parameters.values()
|
|
667
|
+
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
|
|
668
|
+
]
|
|
669
|
+
if self.signature
|
|
670
|
+
else None
|
|
671
|
+
)
|
|
672
|
+
names = (
|
|
673
|
+
[p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
|
|
674
|
+
if self.signature
|
|
675
|
+
else None
|
|
652
676
|
)
|
|
653
|
-
self.forward_ordered_parameter_names = list(self.signature.parameters)
|
|
654
|
-
self.forward_positioned_parameter_names = [
|
|
655
|
-
p.name
|
|
656
|
-
for p in self.signature.parameters.values()
|
|
657
|
-
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
|
|
658
|
-
]
|
|
659
|
-
names = [
|
|
660
|
-
p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
|
|
661
|
-
]
|
|
662
677
|
self.forward_args = names[0] if names else None
|
|
663
|
-
names =
|
|
678
|
+
names = (
|
|
679
|
+
[p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
|
|
680
|
+
if self.signature
|
|
681
|
+
else None
|
|
682
|
+
)
|
|
664
683
|
self.forward_kwargs = names[0] if names else None
|
|
665
684
|
self.forward_custom_op_schema = None
|
|
666
685
|
self.forward_need_serialization = False
|
|
@@ -711,6 +730,7 @@ class ModelInputs:
|
|
|
711
730
|
@property
|
|
712
731
|
def true_model_name(self) -> str:
|
|
713
732
|
"Returns class name or module name."
|
|
733
|
+
assert self.model is not None, "model was None when the class was initialized."
|
|
714
734
|
return (
|
|
715
735
|
self.model.__class__.__name__
|
|
716
736
|
if isinstance(self.model, torch.nn.Module)
|
|
@@ -942,7 +962,7 @@ class ModelInputs:
|
|
|
942
962
|
)
|
|
943
963
|
)
|
|
944
964
|
names = s2.pop()
|
|
945
|
-
for name in names:
|
|
965
|
+
for i, name in enumerate(names):
|
|
946
966
|
assert name not in {"_diag", "verbose"}, (
|
|
947
967
|
f"{self.full_name}: unexpected parameter {name!r}, names={names}"
|
|
948
968
|
f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
|
|
@@ -968,6 +988,14 @@ class ModelInputs:
|
|
|
968
988
|
with the corresponding dynamic shapes.
|
|
969
989
|
*kwargs*, *dynamic_shapes* are modified inplace.
|
|
970
990
|
"""
|
|
991
|
+
assert (
|
|
992
|
+
self.signature is not None
|
|
993
|
+
and self.forward_parameter_names is not None
|
|
994
|
+
and self.forward_ordered_parameter_names is not None
|
|
995
|
+
), (
|
|
996
|
+
"model was None when the class was initialized, "
|
|
997
|
+
"cannot move args to kwargs without the signature."
|
|
998
|
+
)
|
|
971
999
|
sig = self.signature
|
|
972
1000
|
arg_dyn, kw_dyn = dynamic_shapes
|
|
973
1001
|
for i, p in enumerate(sig.parameters):
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Set, Tuple, Union
|
|
2
|
+
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
3
|
+
from .dynamic_shapes import ModelInputs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Builds dynamic shapes where every dimension of every tensor
    found in *inputs* is declared as dynamic.

    :param inputs: the inputs, the structure is walked through by
        :func:`flatten_unflatten_for_dynamic_shapes
        <onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes>`
    :param dim_prefix: if it is a string, every dimension receives
        the name ``<dim_prefix>_<n>_<axis>`` where ``n`` is
        incremented every time a tensor is converted,
        otherwise the value is used as is for every dimension
        (``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC`` for example)
    :return: the dynamic shapes

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim),
                  torch.randn(bsize, nheads, slen, dim))]
            ),
        )
        ds = all_dynamic_shape_from_inputs(inputs)
        pprint.pprint(ds)
    """
    if not isinstance(dim_prefix, str):

        def tensor_to_shape(tensor):
            # The same object is used for every axis.
            return dict.fromkeys(range(tensor.ndim), dim_prefix)

    else:
        converted = 0

        def tensor_to_shape(tensor):
            # One distinct stem per converted tensor.
            nonlocal converted
            stem = f"{dim_prefix}_{converted}"
            converted += 1
            return {axis: f"{stem}_{axis}" for axis in range(tensor.ndim)}

    return flatten_unflatten_for_dynamic_shapes(
        inputs, use_dict=True, change_function=tensor_to_shape
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def guess_dynamic_shapes_from_inputs(
    inputs: List[Any], auto: Union[bool, str] = False
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Guesses which dimensions are dynamic from several sets of inputs.
    A dimension taking different values across the sets of inputs
    is considered as dynamic, a dimension which does not change
    remains static. The work is delegated to
    ``ModelInputs(None, inputs).guess_dynamic_shapes(auto=auto)``.

    :param inputs: a list of input sets
    :param auto: True for ``torch.export.Dim.AUTO``,
        False for ``torch.export.Dim.DYNAMIC``,
        a string to get a unique string for every dynamic dimension
    :return: args and kwargs

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs1 = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )
        bsize, nheads, slen, dim = 3, 1, 33, 96
        inputs2 = dict(
            input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
            position_ids=torch.arange(4, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )
        ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
        pprint.pprint(ds)

    The result is meant to be equivalent to what
    :class:`torch.export.dynamic_shapes.AdditionalInputs` produces,
    but that class needs a model as the following example shows.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
        ds = torch.export.dynamic_shapes.AdditionalInputs()
        ds.add((), data["inputs"])
        ds.add((), data["inputs2"])
        pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
    """
    # No model is given: only the inputs are used to guess the shapes.
    return ModelInputs(None, inputs).guess_dynamic_shapes(auto=auto)
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -756,6 +756,18 @@ class ExtTestCase(unittest.TestCase):
|
|
|
756
756
|
"Adds a todo printed when all test are run."
|
|
757
757
|
cls._todos.append((f, msg))
|
|
758
758
|
|
|
759
|
+
@classmethod
def ort(cls):
    "Returns module onnxruntime, it is only imported when the method is called."
    import onnxruntime as ort_module

    return ort_module
|
|
764
|
+
|
|
765
|
+
@classmethod
def to_onnx(cls, *args, **kwargs):
    """
    Calls ``experimental_experiment.torch_interpreter.to_onnx``
    with the same positional and named arguments and returns its result.
    The function is imported only when the method is called.
    """
    from experimental_experiment.torch_interpreter import to_onnx

    return to_onnx(*args, **kwargs)
|
|
770
|
+
|
|
759
771
|
def print_model(self, model: "ModelProto"): # noqa: F821
|
|
760
772
|
"Prints a ModelProto"
|
|
761
773
|
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
|
|
@@ -917,6 +929,15 @@ class ExtTestCase(unittest.TestCase):
|
|
|
917
929
|
]
|
|
918
930
|
raise AssertionError("\n".join(rows)) # noqa: B904
|
|
919
931
|
|
|
932
|
+
def assertEqualDataFrame(self, d1, d2, **kwargs):
    """
    Checks that two dataframes are equal.
    The comparison is done by :func:`pandas.testing.assert_frame_equal`,
    *kwargs* are sent to this function.
    """
    import pandas.testing as pandas_testing

    pandas_testing.assert_frame_equal(d1, d2, **kwargs)
|
|
940
|
+
|
|
920
941
|
def assertEqualTrue(self, value: Any, msg: str = ""):
|
|
921
942
|
if value is True:
|
|
922
943
|
return
|
|
@@ -967,6 +988,16 @@ class ExtTestCase(unittest.TestCase):
|
|
|
967
988
|
atol=atol,
|
|
968
989
|
rtol=rtol,
|
|
969
990
|
)
|
|
991
|
+
elif expected.__class__.__name__ == "StaticCache":
|
|
992
|
+
self.assertEqual(type(expected), type(value), msg=msg)
|
|
993
|
+
self.assertEqual(expected.max_cache_len, value.max_cache_len)
|
|
994
|
+
atts = ["key_cache", "value_cache"]
|
|
995
|
+
self.assertEqualAny(
|
|
996
|
+
{k: expected.__dict__.get(k, None) for k in atts},
|
|
997
|
+
{k: value.__dict__.get(k, None) for k in atts},
|
|
998
|
+
atol=atol,
|
|
999
|
+
rtol=rtol,
|
|
1000
|
+
)
|
|
970
1001
|
elif expected.__class__.__name__ == "EncoderDecoderCache":
|
|
971
1002
|
self.assertEqual(type(expected), type(value), msg=msg)
|
|
972
1003
|
atts = ["self_attention_cache", "cross_attention_cache"]
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
from typing import Any, List, Tuple
|
|
1
|
+
from typing import Any, Callable, List, Optional, Tuple
|
|
2
2
|
import packaging.version as pv
|
|
3
3
|
import torch
|
|
4
4
|
import transformers
|
|
5
5
|
import transformers.cache_utils
|
|
6
6
|
|
|
7
7
|
|
|
8
|
-
def flatten_unflatten_for_dynamic_shapes(
|
|
8
|
+
def flatten_unflatten_for_dynamic_shapes(
|
|
9
|
+
obj: Any,
|
|
10
|
+
use_dict: bool = False,
|
|
11
|
+
change_function: Optional[Callable[[torch.Tensor], Any]] = None,
|
|
12
|
+
) -> Any:
|
|
9
13
|
"""
|
|
10
14
|
Returns the object in a different structure similar to what
|
|
11
15
|
the definition of the dynamic shapes should use.
|
|
@@ -15,11 +19,13 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
|
|
|
15
19
|
:func:`torch.export.export` only considers the values,
|
|
16
20
|
the context gives the dictionary keys but it is not expressed
|
|
17
21
|
in the dynamic shapes, these specifications seems to be different
|
|
18
|
-
for the strict and non strict mode.
|
|
22
|
+
for the strict and non strict mode. It also preserves tuple.
|
|
23
|
+
:param change_function: to modifies the tensor in the structure itself,
|
|
24
|
+
like replace them by a shape
|
|
19
25
|
:return: the serialized object
|
|
20
26
|
"""
|
|
21
27
|
if isinstance(obj, torch.Tensor):
|
|
22
|
-
return obj
|
|
28
|
+
return change_function(obj) if change_function else obj
|
|
23
29
|
flat, spec = torch.utils._pytree.tree_flatten(obj)
|
|
24
30
|
start = 0
|
|
25
31
|
end = 0
|
|
@@ -27,12 +33,17 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An
|
|
|
27
33
|
for subspec in spec.children_specs:
|
|
28
34
|
end += subspec.num_leaves
|
|
29
35
|
value = subspec.unflatten(flat[start:end])
|
|
30
|
-
value = flatten_unflatten_for_dynamic_shapes(
|
|
36
|
+
value = flatten_unflatten_for_dynamic_shapes(
|
|
37
|
+
value, use_dict=use_dict, change_function=change_function
|
|
38
|
+
)
|
|
31
39
|
subtrees.append(value)
|
|
32
40
|
start = end
|
|
33
|
-
if use_dict
|
|
34
|
-
|
|
35
|
-
|
|
41
|
+
if use_dict:
|
|
42
|
+
if spec.type is dict or spec.context:
|
|
43
|
+
# This a dictionary.
|
|
44
|
+
return dict(zip(spec.context, subtrees))
|
|
45
|
+
if spec.type is tuple:
|
|
46
|
+
return tuple(subtrees)
|
|
36
47
|
# This is a list.
|
|
37
48
|
return subtrees
|
|
38
49
|
|
|
@@ -143,10 +154,12 @@ else:
|
|
|
143
154
|
|
|
144
155
|
def make_static_cache(
|
|
145
156
|
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
157
|
+
max_cache_len: Optional[int] = None,
|
|
146
158
|
) -> transformers.cache_utils.DynamicCache:
|
|
147
159
|
"""
|
|
148
160
|
Creates an instance of :class:`transformers.cache_utils.StaticCache`.
|
|
149
161
|
:param key_value_pairs: list of pairs of (key, values)
|
|
162
|
+
:param max_cache_len: max_cache_length or something inferred from the vector
|
|
150
163
|
:return: :class:`transformers.cache_utils.StaticCache`
|
|
151
164
|
|
|
152
165
|
Example:
|
|
@@ -168,7 +181,8 @@ def make_static_cache(
|
|
|
168
181
|
torch.randn(bsize, nheads, slen, dim),
|
|
169
182
|
)
|
|
170
183
|
for i in range(n_layers)
|
|
171
|
-
]
|
|
184
|
+
],
|
|
185
|
+
max_cache_len=10,
|
|
172
186
|
)
|
|
173
187
|
print(string_type(past_key_values, with_shape=True))
|
|
174
188
|
"""
|
|
@@ -179,24 +193,32 @@ def make_static_cache(
|
|
|
179
193
|
self.num_attention_heads = key_value_pairs[0][0].shape[1]
|
|
180
194
|
self.num_hidden_layers = len(key_value_pairs)
|
|
181
195
|
|
|
196
|
+
assert max_cache_len is not None, (
|
|
197
|
+
f"max_cache_len={max_cache_len} cannot be setup "
|
|
198
|
+
f"automatically yet from shape {key_value_pairs[0][0].shape}"
|
|
199
|
+
)
|
|
200
|
+
torch._check(
|
|
201
|
+
max_cache_len >= key_value_pairs[0][0].shape[2],
|
|
202
|
+
(
|
|
203
|
+
f"max_cache_len={max_cache_len} cannot be smaller "
|
|
204
|
+
f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
|
|
205
|
+
f"{key_value_pairs[0][0].shape}"
|
|
206
|
+
),
|
|
207
|
+
)
|
|
182
208
|
cache = transformers.cache_utils.StaticCache(
|
|
183
209
|
_config(),
|
|
184
210
|
max_batch_size=key_value_pairs[0][0].shape[0],
|
|
185
211
|
device=key_value_pairs[0][0].device,
|
|
186
212
|
dtype=key_value_pairs[0][0].dtype,
|
|
187
|
-
max_cache_len=
|
|
213
|
+
max_cache_len=max_cache_len,
|
|
188
214
|
)
|
|
189
215
|
for i in range(len(key_value_pairs)):
|
|
190
|
-
assert
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
cache.key_cache[i][:, :,
|
|
195
|
-
|
|
196
|
-
f"Shape mismatch, expected {cache.value_cache[i].shape}, "
|
|
197
|
-
f"got {key_value_pairs[i][1].shape}"
|
|
198
|
-
)
|
|
199
|
-
cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
|
|
216
|
+
assert (
|
|
217
|
+
key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
|
|
218
|
+
), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
|
|
219
|
+
d = key_value_pairs[i][1].shape[2]
|
|
220
|
+
cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
|
|
221
|
+
cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
|
|
200
222
|
return cache
|
|
201
223
|
|
|
202
224
|
|