onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +2 -1
- onnx_diagnostic/export/shape_helper.py +47 -70
- onnx_diagnostic/ext_test_case.py +11 -0
- onnx_diagnostic/helpers/cache_helper.py +38 -7
- onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
- onnx_diagnostic/helpers/helper.py +27 -33
- onnx_diagnostic/helpers/log_helper.py +109 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/model_builder_helper.py +132 -2
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +4 -0
- onnx_diagnostic/helpers/rt_helper.py +393 -43
- onnx_diagnostic/helpers/torch_helper.py +20 -1
- onnx_diagnostic/tasks/__init__.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
- onnx_diagnostic/tasks/feature_extraction.py +2 -8
- onnx_diagnostic/tasks/image_text_to_text.py +10 -8
- onnx_diagnostic/tasks/summarization.py +2 -8
- onnx_diagnostic/tasks/text2text_generation.py +3 -8
- onnx_diagnostic/tasks/text_generation.py +86 -65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
- onnx_diagnostic/torch_models/validate.py +626 -228
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -265,7 +265,7 @@ def get_parser_config() -> ArgumentParser:
|
|
|
265
265
|
"--mop",
|
|
266
266
|
metavar="KEY=VALUE",
|
|
267
267
|
nargs="*",
|
|
268
|
-
help="Additional model options,
|
|
268
|
+
help="Additional model options, used to change some parameters of the model, "
|
|
269
269
|
"example:\n --mop attn_implementation=sdpa or --mop attn_implementation=eager",
|
|
270
270
|
action=_ParseDict,
|
|
271
271
|
)
|
|
@@ -442,11 +442,17 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
442
442
|
default=True,
|
|
443
443
|
action=_BoolOrParseDictPatch,
|
|
444
444
|
nargs="*",
|
|
445
|
-
help=
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
445
|
+
help=textwrap.dedent(
|
|
446
|
+
"""
|
|
447
|
+
Applies patches before exporting, it can be a boolean
|
|
448
|
+
to enable to disable the patches or be more finetuned
|
|
449
|
+
(default is True). It is possible to disable patch for torch
|
|
450
|
+
by adding:
|
|
451
|
+
--patch "patch_sympy=False" --patch "patch_torch=False"
|
|
452
|
+
""".strip(
|
|
453
|
+
"\n"
|
|
454
|
+
)
|
|
455
|
+
),
|
|
450
456
|
)
|
|
451
457
|
parser.add_argument(
|
|
452
458
|
"--rewrite",
|
|
@@ -476,10 +482,16 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
476
482
|
"--inputs2",
|
|
477
483
|
default=1,
|
|
478
484
|
type=int,
|
|
479
|
-
help=
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
485
|
+
help=textwrap.dedent(
|
|
486
|
+
"""
|
|
487
|
+
Validates or exports the model on a second set of inputs
|
|
488
|
+
to check the exported model supports dynamism. The values is used
|
|
489
|
+
as an increment to the first set of inputs. A high value may trick
|
|
490
|
+
a different behavior in the model and missed by the exporter.
|
|
491
|
+
""".strip(
|
|
492
|
+
"\n"
|
|
493
|
+
)
|
|
494
|
+
),
|
|
483
495
|
)
|
|
484
496
|
parser.add_argument(
|
|
485
497
|
"--runtime",
|
|
@@ -512,9 +524,15 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
512
524
|
parser.add_argument(
|
|
513
525
|
"--ortfusiontype",
|
|
514
526
|
required=False,
|
|
515
|
-
help=
|
|
516
|
-
|
|
517
|
-
|
|
527
|
+
help=textwrap.dedent(
|
|
528
|
+
"""
|
|
529
|
+
Applies onnxruntime fusion, this parameter should contain the
|
|
530
|
+
model type or multiple values separated by `|`. `ALL` can be used
|
|
531
|
+
to run them all.
|
|
532
|
+
""".strip(
|
|
533
|
+
"\n"
|
|
534
|
+
)
|
|
535
|
+
),
|
|
518
536
|
)
|
|
519
537
|
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
|
|
520
538
|
parser.add_argument("--dtype", help="Changes dtype if necessary.")
|
|
@@ -523,18 +541,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
523
541
|
"--iop",
|
|
524
542
|
metavar="KEY=VALUE",
|
|
525
543
|
nargs="*",
|
|
526
|
-
help=
|
|
527
|
-
|
|
528
|
-
|
|
544
|
+
help=textwrap.dedent(
|
|
545
|
+
"""
|
|
546
|
+
Additional input options, used to change the default
|
|
547
|
+
inputs use to export. Examples:
|
|
548
|
+
--iop cls_cache=SlidingWindowCache
|
|
549
|
+
--iop cls_cache=StaticCache
|
|
550
|
+
""".strip(
|
|
551
|
+
"\n"
|
|
552
|
+
)
|
|
553
|
+
),
|
|
529
554
|
action=_ParseDict,
|
|
530
555
|
)
|
|
531
556
|
parser.add_argument(
|
|
532
557
|
"--mop",
|
|
533
558
|
metavar="KEY=VALUE",
|
|
534
559
|
nargs="*",
|
|
535
|
-
help=
|
|
536
|
-
|
|
537
|
-
|
|
560
|
+
help=textwrap.dedent(
|
|
561
|
+
"""
|
|
562
|
+
Additional model options, used to change some parameters
|
|
563
|
+
of the model. Example:
|
|
564
|
+
--mop attn_implementation=sdpa --mop attn_implementation=eager"
|
|
565
|
+
--mop "rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}"
|
|
566
|
+
""".strip(
|
|
567
|
+
"\n"
|
|
568
|
+
)
|
|
569
|
+
),
|
|
538
570
|
action=_ParseDict,
|
|
539
571
|
)
|
|
540
572
|
if name == "validate":
|
|
@@ -566,9 +598,32 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
566
598
|
parser.add_argument(
|
|
567
599
|
"--quiet-input-sets",
|
|
568
600
|
default="",
|
|
569
|
-
help=
|
|
570
|
-
|
|
601
|
+
help=textwrap.dedent(
|
|
602
|
+
"""
|
|
603
|
+
Avoids raising an exception when an input sets does not work with
|
|
604
|
+
the exported model. Example:
|
|
605
|
+
--quiet-input-sets=inputs,inputs22
|
|
606
|
+
""".strip(
|
|
607
|
+
"\n"
|
|
608
|
+
)
|
|
609
|
+
),
|
|
571
610
|
)
|
|
611
|
+
parser.add_argument(
|
|
612
|
+
"--expop",
|
|
613
|
+
metavar="KEY=VALUE",
|
|
614
|
+
nargs="*",
|
|
615
|
+
help=textwrap.dedent(
|
|
616
|
+
"""
|
|
617
|
+
Additional exporter options, use to change some parameters
|
|
618
|
+
of the model. Examples:
|
|
619
|
+
--expop report=True
|
|
620
|
+
--expop report=True --expop verify=True
|
|
621
|
+
""".strip(
|
|
622
|
+
"\n"
|
|
623
|
+
)
|
|
624
|
+
),
|
|
625
|
+
action=_ParseDict,
|
|
626
|
+
)
|
|
572
627
|
return parser
|
|
573
628
|
|
|
574
629
|
|
|
@@ -634,6 +689,7 @@ def _cmd_validate(argv: List[Any]):
|
|
|
634
689
|
output_names=(
|
|
635
690
|
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
|
|
636
691
|
),
|
|
692
|
+
exporter_options=args.expop,
|
|
637
693
|
)
|
|
638
694
|
print("")
|
|
639
695
|
print("-- summary --")
|
|
@@ -940,7 +996,7 @@ def get_parser_agg() -> ArgumentParser:
|
|
|
940
996
|
"n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
|
|
941
997
|
"n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
|
|
942
998
|
"n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
|
|
943
|
-
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
|
|
999
|
+
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,n_node_random,"
|
|
944
1000
|
"n_node_constant,n_node_shape,n_node_expand,"
|
|
945
1001
|
"n_node_function,n_node_initializer,n_node_scatter,"
|
|
946
1002
|
"time_export_unbiased,onnx_n_nodes_no_cst,n_node_initializer_small",
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Sequence, Optional, Tuple, Union
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def to_onnx(
|
|
6
|
+
mod: Union["torch.nn.Module", "torch.fx.GraphModule"], # noqa: F821
|
|
7
|
+
args: Optional[Sequence["torch.Tensor"]] = None, # noqa: F821
|
|
8
|
+
kwargs: Optional[Dict[str, "torch.Tensor"]] = None, # noqa: F821
|
|
9
|
+
input_names: Optional[Sequence[str]] = None,
|
|
10
|
+
target_opset: Optional[Union[int, Dict[str, int]]] = None,
|
|
11
|
+
verbose: int = 0,
|
|
12
|
+
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
13
|
+
filename: Optional[str] = None,
|
|
14
|
+
output_names: Optional[List[str]] = None,
|
|
15
|
+
output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
|
16
|
+
exporter: str = "onnx-dynamo",
|
|
17
|
+
) -> Any:
|
|
18
|
+
"""
|
|
19
|
+
Common API for exporters. By default, the models are optimized to use the
|
|
20
|
+
most efficient kernels implemented in :epkg:`onnxruntime`.
|
|
21
|
+
|
|
22
|
+
:param mod: torch model
|
|
23
|
+
:param args: unnamed arguments
|
|
24
|
+
:param kwargs: named arguments
|
|
25
|
+
:param input_names: input names for the onnx model (optional)
|
|
26
|
+
:param target_opset: opset to target, if not specified, each converter
|
|
27
|
+
keeps its default value
|
|
28
|
+
:param verbose: verbosity level
|
|
29
|
+
:param dynamic_shapes: dynamic shapes, usually a nested structure
|
|
30
|
+
included a dictionary for each tensor
|
|
31
|
+
:param filename: output filename
|
|
32
|
+
:param output_names: to change the output of the onnx model
|
|
33
|
+
:param output_dynamic_shapes: to overwrite the dynamic shapes names
|
|
34
|
+
:param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
|
|
35
|
+
:return: the output of the selected exporter, usually a structure including
|
|
36
|
+
an onnx model
|
|
37
|
+
|
|
38
|
+
A simple example:
|
|
39
|
+
|
|
40
|
+
.. code-block:: python
|
|
41
|
+
|
|
42
|
+
to_onnx(
|
|
43
|
+
model,
|
|
44
|
+
kwargs=inputs,
|
|
45
|
+
dynamic_shapes=ds,
|
|
46
|
+
exporter=exporter,
|
|
47
|
+
filename=filename,
|
|
48
|
+
)
|
|
49
|
+
"""
|
|
50
|
+
if exporter == "custom":
|
|
51
|
+
from experimental_experiment.torch_interpreter import to_onnx as _to_onnx
|
|
52
|
+
from experimental_experiment.xbuilder import OptimizationOptions
|
|
53
|
+
|
|
54
|
+
return _to_onnx(
|
|
55
|
+
mod,
|
|
56
|
+
args=args,
|
|
57
|
+
kwargs=kwargs,
|
|
58
|
+
input_names=input_names,
|
|
59
|
+
output_names=output_names,
|
|
60
|
+
target_opset=target_opset,
|
|
61
|
+
verbose=verbose,
|
|
62
|
+
filename=filename,
|
|
63
|
+
dynamic_shapes=dynamic_shapes,
|
|
64
|
+
large_model=True,
|
|
65
|
+
output_dynamic_shapes=output_dynamic_shapes,
|
|
66
|
+
options=OptimizationOptions(patterns="default+onnxruntime"),
|
|
67
|
+
)
|
|
68
|
+
if exporter in ("dynamo", "onnx-dynamo"):
|
|
69
|
+
import onnxscript.rewriter.ort_fusions as ort_fusions
|
|
70
|
+
|
|
71
|
+
assert (
|
|
72
|
+
not output_dynamic_shapes
|
|
73
|
+
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
|
|
74
|
+
epo = torch.onnx.export(
|
|
75
|
+
mod,
|
|
76
|
+
args=args or tuple(),
|
|
77
|
+
kwargs=kwargs,
|
|
78
|
+
input_names=input_names,
|
|
79
|
+
output_names=output_names,
|
|
80
|
+
opset_version=target_opset,
|
|
81
|
+
dynamic_shapes=dynamic_shapes,
|
|
82
|
+
dynamo=True,
|
|
83
|
+
)
|
|
84
|
+
ort_fusions.optimize_for_ort(epo.model)
|
|
85
|
+
epo.save(filename)
|
|
86
|
+
return epo
|
|
87
|
+
|
|
88
|
+
if exporter == "modelbuilder":
|
|
89
|
+
import os
|
|
90
|
+
from ..helpers import flatten_object, string_type
|
|
91
|
+
from ..helpers.model_builder_helper import create_model_builder, save_model_builder
|
|
92
|
+
|
|
93
|
+
assert filename, f"filename must be specified for exporter={exporter!r}"
|
|
94
|
+
assert (
|
|
95
|
+
not output_dynamic_shapes
|
|
96
|
+
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
|
|
97
|
+
assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}"
|
|
98
|
+
assert not args, f"only kwargs can be defined with exporter={exporter!r}"
|
|
99
|
+
assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], ( # type: ignore[arg-type]
|
|
100
|
+
f"Only a specified set of inputs is supported for exporter={exporter!r}, "
|
|
101
|
+
f"but it is {list(kwargs)}" # type: ignore[arg-type]
|
|
102
|
+
)
|
|
103
|
+
flat_inputs = flatten_object(kwargs, drop_keys=True)
|
|
104
|
+
first = flat_inputs[0]
|
|
105
|
+
first_float = [
|
|
106
|
+
t
|
|
107
|
+
for t in flat_inputs
|
|
108
|
+
if t.dtype in {torch.float32, torch.double, torch.float16, torch.bfloat16}
|
|
109
|
+
]
|
|
110
|
+
assert first_float, (
|
|
111
|
+
f"Unable to find a float tensor in the inputs "
|
|
112
|
+
f"{string_type(kwargs, with_shape=True)}"
|
|
113
|
+
)
|
|
114
|
+
onx = create_model_builder(
|
|
115
|
+
mod.config,
|
|
116
|
+
mod,
|
|
117
|
+
precision=str(first_float[0].dtype).split(".")[-1],
|
|
118
|
+
execution_provider="cuda" if first.is_cuda else "cpu",
|
|
119
|
+
cache_dir=os.path.dirname(filename),
|
|
120
|
+
)
|
|
121
|
+
save_model_builder(onx, os.path.dirname(filename))
|
|
122
|
+
return onx
|
|
123
|
+
|
|
124
|
+
raise ValueError(f"Unknown exporter={exporter!r}")
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import itertools
|
|
2
3
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
3
4
|
import numpy as np
|
|
4
5
|
import torch
|
|
@@ -934,7 +935,7 @@ class ModelInputs:
|
|
|
934
935
|
auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
|
|
935
936
|
)
|
|
936
937
|
)
|
|
937
|
-
return
|
|
938
|
+
return list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
|
|
938
939
|
|
|
939
940
|
raise NotImplementedError(
|
|
940
941
|
f"Unable to build dynamic shapes for type {set_types.pop()}: "
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from typing import Any, Dict, List, Set, Optional, Tuple, Union
|
|
2
2
|
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
3
|
-
from ..helpers.fake_tensor_helper import fake_reshape
|
|
4
3
|
from .dynamic_shapes import ModelInputs
|
|
5
4
|
|
|
6
5
|
|
|
@@ -203,14 +202,49 @@ def guess_dynamic_shapes_from_inputs(
|
|
|
203
202
|
|
|
204
203
|
|
|
205
204
|
def make_fake_with_dynamic_dimensions(
|
|
206
|
-
x: Any,
|
|
207
|
-
|
|
208
|
-
fake_mode: Optional["FakeTensorMode"] = None, # noqa: F821
|
|
209
|
-
) -> Tuple[Any, "FakeTensorMode"]: # noqa: F821
|
|
205
|
+
x: Any, dynamic_shapes: Any, context: Optional["FakeTensorContext"] = None # noqa: F821
|
|
206
|
+
) -> Tuple[Any, "FakeTensorContext"]: # noqa: F821
|
|
210
207
|
"""
|
|
211
208
|
Replaces all tensors by fake tensor respecting the same
|
|
212
209
|
constraints as the following dynamic shapes.
|
|
213
210
|
This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
|
|
211
|
+
Parameter ``existing`` is used to reused the same object when the dynamic
|
|
212
|
+
dimension is given the same name as another one.
|
|
213
|
+
|
|
214
|
+
A simple tensor:
|
|
215
|
+
|
|
216
|
+
.. runpython::
|
|
217
|
+
:showcode:
|
|
218
|
+
|
|
219
|
+
import torch
|
|
220
|
+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
|
|
221
|
+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
|
|
222
|
+
|
|
223
|
+
inputs, _ = make_fake_with_dynamic_dimensions(
|
|
224
|
+
torch.rand((2, 3, 4, 5), dtype=torch.float32),
|
|
225
|
+
{0: "batch", 2: "cache_length"},
|
|
226
|
+
)
|
|
227
|
+
print(inputs)
|
|
228
|
+
|
|
229
|
+
Two tensors:
|
|
230
|
+
|
|
231
|
+
.. runpython::
|
|
232
|
+
:showcode:
|
|
233
|
+
|
|
234
|
+
import torch
|
|
235
|
+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
|
|
236
|
+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
|
|
237
|
+
|
|
238
|
+
inputs, _ = make_fake_with_dynamic_dimensions(
|
|
239
|
+
(
|
|
240
|
+
torch.rand((2, 3, 4, 5), dtype=torch.float32),
|
|
241
|
+
torch.rand((2, 3, 4, 5), dtype=torch.float32),
|
|
242
|
+
),
|
|
243
|
+
({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
|
|
244
|
+
)
|
|
245
|
+
print(inputs)
|
|
246
|
+
|
|
247
|
+
With a cache:
|
|
214
248
|
|
|
215
249
|
.. runpython::
|
|
216
250
|
:showcode:
|
|
@@ -243,8 +277,10 @@ def make_fake_with_dynamic_dimensions(
|
|
|
243
277
|
"attention_mask": {0: "batch", 1: "cache+seq"},
|
|
244
278
|
"position_ids": {0: "batch", 1: "seq_length"},
|
|
245
279
|
"past_key_values": [
|
|
246
|
-
|
|
247
|
-
|
|
280
|
+
{0: "batch", 2: "cache_length"},
|
|
281
|
+
{0: "batch", 2: "cache_length"},
|
|
282
|
+
{0: "batch", 2: "cache_length"},
|
|
283
|
+
{0: "batch", 2: "cache_length"},
|
|
248
284
|
],
|
|
249
285
|
},
|
|
250
286
|
)
|
|
@@ -252,68 +288,9 @@ def make_fake_with_dynamic_dimensions(
|
|
|
252
288
|
"""
|
|
253
289
|
if x is None:
|
|
254
290
|
return None, None
|
|
255
|
-
if
|
|
256
|
-
from
|
|
257
|
-
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
291
|
+
if context is None:
|
|
292
|
+
from ..helpers.fake_tensor_helper import FakeTensorContext
|
|
258
293
|
|
|
259
|
-
|
|
260
|
-
fake_mode = FakeTensorMode(shape_env=shape_env)
|
|
294
|
+
context = FakeTensorContext()
|
|
261
295
|
|
|
262
|
-
|
|
263
|
-
return (
|
|
264
|
-
x.__class__(
|
|
265
|
-
[
|
|
266
|
-
make_fake_with_dynamic_dimensions(
|
|
267
|
-
i, fake_mode=fake_mode, dynamic_shapes=ds
|
|
268
|
-
)[0]
|
|
269
|
-
for i, ds in zip(x, dynamic_shapes)
|
|
270
|
-
]
|
|
271
|
-
),
|
|
272
|
-
fake_mode,
|
|
273
|
-
)
|
|
274
|
-
if isinstance(x, dict):
|
|
275
|
-
return {
|
|
276
|
-
k: make_fake_with_dynamic_dimensions(
|
|
277
|
-
v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
|
|
278
|
-
)[0]
|
|
279
|
-
for k, v in x.items()
|
|
280
|
-
}, fake_mode
|
|
281
|
-
|
|
282
|
-
if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
|
|
283
|
-
assert hasattr(x, "layers"), (
|
|
284
|
-
f"Une more recent version of transformers (>=4.55), "
|
|
285
|
-
f"'layers' not found in class {type(x)}"
|
|
286
|
-
)
|
|
287
|
-
assert (
|
|
288
|
-
isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
|
|
289
|
-
), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
|
|
290
|
-
for il, layer in enumerate(x.layers):
|
|
291
|
-
assert hasattr(layer, "keys") and hasattr(layer, "values"), (
|
|
292
|
-
f"Une more recent version of transformers (>=4.55), 'layers' "
|
|
293
|
-
f"not found in class {type(layer)} ({dir(layer)})"
|
|
294
|
-
)
|
|
295
|
-
layer.keys = make_fake_with_dynamic_dimensions(
|
|
296
|
-
layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
|
|
297
|
-
)[0]
|
|
298
|
-
layer.values = make_fake_with_dynamic_dimensions(
|
|
299
|
-
layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
|
|
300
|
-
)[0]
|
|
301
|
-
return x, fake_mode
|
|
302
|
-
if x.__class__.__name__ == "EncoderDecoderCache":
|
|
303
|
-
make_fake_with_dynamic_dimensions(
|
|
304
|
-
x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
|
|
305
|
-
)
|
|
306
|
-
make_fake_with_dynamic_dimensions(
|
|
307
|
-
x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
|
|
308
|
-
)
|
|
309
|
-
return x, fake_mode
|
|
310
|
-
if hasattr(x, "shape"):
|
|
311
|
-
t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
|
|
312
|
-
assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
|
|
313
|
-
assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
|
|
314
|
-
return t, fake_mode
|
|
315
|
-
from ..helpers import string_type
|
|
316
|
-
|
|
317
|
-
raise TypeError(
|
|
318
|
-
f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
|
|
319
|
-
)
|
|
296
|
+
return context.make_fake_with_dynamic_dimensions(x, dynamic_shapes), context
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -630,6 +630,17 @@ def has_onnxruntime_training(push_back_batch: bool = False):
|
|
|
630
630
|
return True
|
|
631
631
|
|
|
632
632
|
|
|
633
|
+
def has_onnxruntime_genai():
|
|
634
|
+
"""Tells if onnxruntime_genai is installed."""
|
|
635
|
+
try:
|
|
636
|
+
import onnxruntime_genai # noqa: F401
|
|
637
|
+
|
|
638
|
+
return True
|
|
639
|
+
except ImportError:
|
|
640
|
+
# onnxruntime not training
|
|
641
|
+
return False
|
|
642
|
+
|
|
643
|
+
|
|
633
644
|
def requires_onnxruntime_training(
|
|
634
645
|
push_back_batch: bool = False, ortmodule: bool = False, msg: str = ""
|
|
635
646
|
) -> Callable:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable, List, Optional, Tuple
|
|
1
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
2
2
|
import packaging.version as pv
|
|
3
3
|
import torch
|
|
4
4
|
import transformers
|
|
@@ -46,9 +46,14 @@ class CacheKeyValue:
|
|
|
46
46
|
raise NotImplementedError(f"type(cache)={type(cache)}")
|
|
47
47
|
|
|
48
48
|
def make_dynamic_cache(self):
|
|
49
|
-
"""
|
|
49
|
+
"""Does the reverse operation."""
|
|
50
50
|
return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
|
|
51
51
|
|
|
52
|
+
@property
|
|
53
|
+
def n_layers(self) -> int:
|
|
54
|
+
"""Returns the number of layers."""
|
|
55
|
+
return len(self.key_cache) if self.key_cache else 0
|
|
56
|
+
|
|
52
57
|
|
|
53
58
|
def flatten_unflatten_for_dynamic_shapes(
|
|
54
59
|
obj: Any,
|
|
@@ -134,10 +139,31 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
|
|
|
134
139
|
return len(cache2.key_cache) == len(cache.value_cache)
|
|
135
140
|
|
|
136
141
|
|
|
142
|
+
def make_dynamic_shapes_kv_cache(
|
|
143
|
+
cache: transformers.cache_utils.Cache, shape_of_one: Dict[int, Any]
|
|
144
|
+
) -> List[Dict[int, Any]]:
|
|
145
|
+
"""
|
|
146
|
+
Returns the dynamic shapes for key-value cache
|
|
147
|
+
|
|
148
|
+
:param cache: a cache
|
|
149
|
+
:param shape_of_one: shape of one element
|
|
150
|
+
:return: dynamic shapes
|
|
151
|
+
"""
|
|
152
|
+
return [shape_of_one for _ in range(CacheKeyValue(cache).n_layers * 2)]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _preprocess_key_value_pairs(
|
|
156
|
+
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
|
|
157
|
+
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
|
158
|
+
if not key_value_pairs or isinstance(key_value_pairs[0], tuple):
|
|
159
|
+
return key_value_pairs
|
|
160
|
+
return list(zip(key_value_pairs[::2], key_value_pairs[1::2]))
|
|
161
|
+
|
|
162
|
+
|
|
137
163
|
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
|
|
138
164
|
|
|
139
165
|
def make_dynamic_cache(
|
|
140
|
-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
166
|
+
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
|
|
141
167
|
) -> transformers.cache_utils.DynamicCache:
|
|
142
168
|
"""
|
|
143
169
|
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
|
|
@@ -173,6 +199,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
|
|
|
173
199
|
``transformers>=4.56``. Before that version, only FakeTensor with static dimensions
|
|
174
200
|
are supported.
|
|
175
201
|
"""
|
|
202
|
+
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
|
|
176
203
|
if (
|
|
177
204
|
key_value_pairs
|
|
178
205
|
and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
|
|
@@ -212,7 +239,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
|
|
|
212
239
|
else:
|
|
213
240
|
|
|
214
241
|
def make_dynamic_cache(
|
|
215
|
-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
242
|
+
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
|
|
216
243
|
) -> transformers.cache_utils.DynamicCache:
|
|
217
244
|
"""
|
|
218
245
|
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
|
|
@@ -244,6 +271,7 @@ else:
|
|
|
244
271
|
)
|
|
245
272
|
print(string_type(past_key_values, with_shape=True))
|
|
246
273
|
"""
|
|
274
|
+
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
|
|
247
275
|
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore
|
|
248
276
|
for i, (key, value) in enumerate(key_value_pairs):
|
|
249
277
|
cache.update(key, value, i)
|
|
@@ -251,7 +279,7 @@ else:
|
|
|
251
279
|
|
|
252
280
|
|
|
253
281
|
def make_static_cache(
|
|
254
|
-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
282
|
+
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
|
|
255
283
|
max_cache_len: Optional[int] = None,
|
|
256
284
|
) -> transformers.cache_utils.DynamicCache:
|
|
257
285
|
"""
|
|
@@ -284,6 +312,7 @@ def make_static_cache(
|
|
|
284
312
|
)
|
|
285
313
|
print(string_type(past_key_values, with_shape=True))
|
|
286
314
|
"""
|
|
315
|
+
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
|
|
287
316
|
|
|
288
317
|
class _config:
|
|
289
318
|
def __init__(self):
|
|
@@ -426,9 +455,10 @@ def make_mamba_cache(
|
|
|
426
455
|
|
|
427
456
|
|
|
428
457
|
def make_sliding_window_cache(
|
|
429
|
-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
458
|
+
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
|
|
430
459
|
) -> transformers.cache_utils.SlidingWindowCache:
|
|
431
460
|
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
|
|
461
|
+
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
|
|
432
462
|
|
|
433
463
|
class _config:
|
|
434
464
|
def __init__(self):
|
|
@@ -481,7 +511,7 @@ def make_sliding_window_cache(
|
|
|
481
511
|
|
|
482
512
|
|
|
483
513
|
def make_hybrid_cache(
|
|
484
|
-
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
514
|
+
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
|
|
485
515
|
max_cache_len: Optional[int] = None,
|
|
486
516
|
max_batch_size: Optional[int] = None,
|
|
487
517
|
sliding_window: Optional[int] = None,
|
|
@@ -566,6 +596,7 @@ def make_hybrid_cache(
|
|
|
566
596
|
self.key_cache.append(new_layer_key_cache)
|
|
567
597
|
self.value_cache.append(new_layer_value_cache)
|
|
568
598
|
"""
|
|
599
|
+
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
|
|
569
600
|
layer_types = None
|
|
570
601
|
if key_value_pairs:
|
|
571
602
|
assert (
|