onnx_diagnostic-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/export/dynamic_shapes.py
@@ -0,0 +1,1083 @@
import inspect
import itertools
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from ..helpers import string_type
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes

DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]


def _flatten_dynamic_shapes(ds: Any) -> Any:
    """Flattens the dynamic shapes."""
    if isinstance(ds, list):
        return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
    if isinstance(ds, tuple):
        return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
    if isinstance(ds, dict):
        if all(isinstance(i, int) for i in ds):
            # That's a dynamic shape
            return ds
        return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
    raise AssertionError(f"Not implemented for {type(ds)}: {ds}")


def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
    res = []
    for t in li:
        if isinstance(t, dict):
            res.append(t)
        else:
            res.extend(t)
    return res

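# Illustrative sketch (not part of the original file): assuming a nested
# dynamic-shapes structure, ``_flatten_dynamic_shapes`` reduces it to a flat
# list of per-tensor dimension dictionaries, e.g.
#
#   _flatten_dynamic_shapes(
#       {"A": {0: "batch"}, "B": ({0: "batch"}, {0: "batch", 1: "seq"})}
#   )
#   -> [{0: "batch"}, {0: "batch"}, {0: "batch", 1: "seq"}]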
class CoupleInputsDynamicShapes:
    """
    Pairs inputs and dynamic shapes.

    :param args: positional arguments
    :param kwargs: named arguments
    :param dynamic_shapes: dynamic shapes
    :param args_names: if both args and kwargs are not empty, then
        dynamic shapes must be a dictionary, and positional arguments must
        be added to the named arguments. Argument names or a module must
        be given in that case.
    """

    def __init__(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        dynamic_shapes: DYNAMIC_SHAPES,
        args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
    ):
        self.args = args
        self.kwargs = kwargs
        self.dynamic_shapes = dynamic_shapes
        self.args_names = args_names
        if not self.kwargs and isinstance(self.dynamic_shapes, dict):
            # This assumes the dictionary for the dynamic shapes is ordered
            # the same way the args are. The input names are not known.
            assert len(self.dynamic_shapes) == len(self.args), (
                f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
                f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
            )
            self.dynamic_shapes = tuple(self.dynamic_shapes.values())

    def __str__(self) -> str:
        return "\n".join(
            [
                f"{self.__class__.__name__}(",
                f" args={string_type(self.args, with_shape=True)},"
                f" kwargs={string_type(self.kwargs, with_shape=True)},"
                f" dynamic_shapes={string_type(self.dynamic_shapes, with_shape=True)},"
                f")",
            ]
        )

    def replace_string_by(self, value: Any = None):
        """
        Replaces strings by the value ``torch.export.Dim.DYNAMIC``
        (default) or by any other value specified by ``value``.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())
        """
        return self._generic_walker(
            lambda inputs, ds, value=value: self._replace_string_dim_tensor(
                inputs, ds, value=value
            ),
            flatten_unflatten=True,
        )

    @classmethod
    def _replace_string_dim_tensor(cls, inputs, ds, value=None):
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension"
        )
        if value is None:
            value = torch.export.Dim.DYNAMIC
        new_ds = ds.copy()
        for i, v in ds.items():
            if isinstance(v, str):
                new_ds[i] = value
        return new_ds

    def replace_by_string(self):
        """
        Replaces dimensions by strings.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            Dim = torch.export.Dim
            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: Dim("batch")}
            ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
        """
        unique = set()
        return self._generic_walker(
            lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
                inputs, ds, unique=unique
            ),
            flatten_unflatten=True,
        )

    @classmethod
    def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension"
        )
        new_ds = ds.copy()
        for i, v in ds.items():
            if isinstance(v, str):
                unique.add(v)
                new_ds[i] = v
            elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
                name = f"Dim{len(unique)}"
                new_ds[i] = name
                unique.add(name)
            else:
                name = v.__name__
                unique.add(name)
                new_ds[i] = name
        return new_ds

    def invalid_dimensions_for_export(self):
        """
        Tells if the inputs are valid based on the dynamic shapes definition.
        The method assumes that all custom classes can be serialized.
        If some patches were applied for the export, they should be enabled
        while calling this method if the inputs contain such classes.

        The function checks that a dynamic dimension does not receive a value
        of 0 or 1. It returns the unexpected values in the same structure as
        the given dynamic shapes.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())

        When the inputs are valid, it returns nothing:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x2 = torch.rand((3, 2))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x2, T3x2)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
        """
        return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)

    @classmethod
    def _valid_shapes_tensor(cls, inputs, ds):
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension"
        )
        issues = {}
        for i, d in enumerate(inputs.shape):
            if i in ds and not isinstance(ds[i], int):
                # dynamic then
                if isinstance(d, int) and d in {0, 1}:
                    # export issues for sure
                    issues[i] = f"d=[{d}]"
        return issues if issues else None

    def _generic_walker(
        self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
    ):
        """
        Generic walker going through the inputs and the dynamic shapes in parallel.
        The function returns a result with the same structure as the dynamic shapes.
        """
        if not self.args:
            assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
                f"Type mismatch, args={string_type(self.args)}, "
                f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
                f"{string_type(self.dynamic_shapes)} should have the same type."
            )
            res = self._generic_walker_step(
                processor,
                self.kwargs,
                self.dynamic_shapes,
                flatten_unflatten=flatten_unflatten,
            )
            return (tuple(), res) if args_kwargs else res

        if not self.kwargs:
            assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
                f"Type mismatch, args={string_type(self.args)} and "
                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
            )
            res = self._generic_walker_step(
                processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
            )
            return (res, {}) if args_kwargs else res

        assert isinstance(self.dynamic_shapes, dict), (
            f"Both positional and named arguments (args and kwargs) are filled. "
            f"dynamic shapes must be a dictionary, not {type(self.dynamic_shapes)}"
        )
        if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
            self.dynamic_shapes
        ):
            # No dynamic shapes for the positional arguments.
            return self._generic_walker_step(
                processor,
                self.kwargs,
                self.dynamic_shapes,
                flatten_unflatten=flatten_unflatten,
            )

        if isinstance(self.args_names, list):
            if not set(self.args_names) & set(self.dynamic_shapes):
                # No dynamic shapes for the positional arguments.
                return self._generic_walker_step(
                    processor,
                    self.kwargs,
                    self.dynamic_shapes,
                    flatten_unflatten=flatten_unflatten,
                )

            assert self.args_names, (
                "args and kwargs are filled, then args_names must be specified in "
                "the constructor to move positional arguments to named arguments."
            )
            assert len(self.args) <= len(self.args_names), (
                f"There are {len(self.args)} positional arguments "
                f"but only {len(self.args_names)} names. "
                f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
            )
            kwargs = dict(zip(self.args_names, self.args))
            kwargs.update(self.kwargs)
            res = self._generic_walker_step(
                processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
            )
            if args_kwargs:
                pgs = [None for _ in range(len(self.args))]
                kws = {}
                for k, v in res.items():
                    if k not in self.kwargs:
                        pgs[self.args_names.index(k)] = v
                    else:
                        kws[k] = v
                return pgs, kws
            return res

        raise NotImplementedError(
            f"Not yet implemented when both args and kwargs are filled "
            f"but args_names is {type(self.args_names)}"
        )

    @classmethod
    def _generic_walker_step(
        cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
    ):
        if isinstance(inputs, torch.Tensor):
            return processor(inputs, ds)
        if isinstance(inputs, (int, float, str)):
            return None
        if type(inputs) in (tuple, list, dict):
            # Type must be strict, some custom classes can inherit from those.
            assert type(inputs) is type(ds), (
                f"Input type and dynamic shape type must match but "
                f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
            )
            assert len(ds) == len(inputs), (
                f"Length mismatch between inputs {len(inputs)} "
                f"and ds={len(ds)}\n"
                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
            )
            if type(inputs) in (tuple, list):
                value = []
                for i, d in zip(inputs, ds):
                    value.append(
                        cls._generic_walker_step(
                            processor, i, d, flatten_unflatten=flatten_unflatten
                        )
                    )
                return (
                    (value if isinstance(ds, list) else tuple(value))
                    if any(v is not None for v in value)
                    else None
                )
            assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
            assert set(inputs) == set(ds), (
                f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
            )
            dvalue = {}
            for k, v in inputs.items():
                t = cls._generic_walker_step(
                    processor, v, ds[k], flatten_unflatten=flatten_unflatten
                )
                if t is not None:
                    dvalue[k] = t
            return dvalue if dvalue else None

        # A custom class.
        assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
            f"Class {inputs.__class__.__name__!r} was not registered using "
            f"torch.utils._pytree.register_pytree_node, it is not possible to "
            f"map this class with the given dynamic shapes."
        )
        if flatten_unflatten:
            flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
            res = cls._generic_walker_step(
                processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
            )
            # Should we restore the original class?
            return res
        flat, spec = torch.utils._pytree.tree_flatten(inputs)
        if all(isinstance(t, torch.Tensor) for t in flat):
            # We need to flatten dynamic shapes as well.
            ds = _flatten_dynamic_shapes(ds)
            res = cls._generic_walker_step(
                processor, flat, ds, flatten_unflatten=flatten_unflatten
            )
            # Then we restore the original class.
            return torch.utils._pytree.tree_unflatten(res, spec)

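    # Illustrative sketch (not part of the original file): the walker above
    # applies a ``processor(tensor, ds)`` callback to every pair made of a
    # tensor and its dimension dictionary, and rebuilds the dynamic-shapes
    # structure from the results, e.g.
    #
    #   cpl = CoupleInputsDynamicShapes((), {"A": torch.rand(3, 4)}, {"A": {0: "batch"}})
    #   cpl._generic_walker(
    #       lambda t, ds: {i: (v, t.shape[i]) for i, v in ds.items()},
    #       flatten_unflatten=True,
    #   )
    #   -> {"A": {0: ("batch", 3)}}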
    class ChangeDimensionProcessor:
        def __init__(self, desired_values, only_desired):
            self.mapping = desired_values or {}
            self.only_desired = only_desired

        def _build_new_shape(
            self, shape: Tuple[int, ...], ds: Dict[int, Any]
        ) -> Tuple[int, ...]:
            new_shape = list(shape)
            for i in range(len(shape)):
                if i in ds:
                    if isinstance(ds[i], str):
                        d = ds[i]
                    elif isinstance(
                        ds[i],
                        (
                            torch.export.dynamic_shapes._DerivedDim,
                            torch.export.dynamic_shapes._Dim,
                        ),
                    ):
                        d = ds[i].__name__
                    elif isinstance(ds[i], int):
                        # A static dimension, leave it unchanged.
                        continue
                    else:
                        raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
                    if d in self.mapping:
                        new_dim = self.mapping[d]
                    elif not self.only_desired:
                        new_dim = shape[i] + 1
                        self.mapping[d] = new_dim
                    else:
                        new_dim = shape[i]
                    new_shape[i] = new_dim
            return tuple(new_shape)

        def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
            rank = len(tensor.shape)
            for i in range(len(tensor.shape)):
                d0 = tensor.shape[i]
                d1 = new_shape[i]
                if d0 == d1:
                    continue
                alt_shape = list(tensor.shape)
                alt_shape[i] = d1
                new_tensor = torch.zeros(
                    tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
                )
                mind = min(d0, d1)
                indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
                indices[i] = slice(0, mind)
                ind = tuple(indices)
                new_tensor[ind] = tensor[ind]
                if d1 > mind:
                    for k in range(d1 - mind):
                        indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
                        indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
                        indices1[i] = mind + k
                        indices0[i] = k % mind
                        new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
                tensor = new_tensor
            return tensor

        def __call__(self, inputs, ds):
            assert isinstance(
                inputs, torch.Tensor
            ), f"unexpected type for inputs {type(inputs)}"
            assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
                f"Unexpected types, inputs is a Tensor but ds is {ds}, "
                f"a dictionary is expected to specify a dimension"
            )
            new_shape = self._build_new_shape(inputs.shape, ds)
            return self._build_new_tensor(inputs, new_shape)

    def change_dynamic_dimensions(
        self,
        desired_values: Optional[Dict[str, int]] = None,
        args_kwargs: bool = False,
        only_desired: bool = False,
    ):
        """
        A model exported with dynamic shapes is not necessarily dynamic
        just because the user specified dynamic shapes. The algorithm
        may discover that a dimension cannot be dynamic and then continue
        the export under the assumption it is static. That may lead to a wrong
        model. This function produces a new set of inputs with different values
        for the dynamic dimensions, assuming the first ones were used to export
        the model.

        :param desired_values: fixes a named dimension to the desired value
        :param args_kwargs: returns both args and kwargs even if empty
        :param only_desired: if True, only changes the dimensions specified in
            ``desired_values``
        :return: new inputs

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x15 = torch.rand((3, 15))
            T3x20 = torch.rand((3, 20))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x15, T3x20)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
            print("before:", string_type(kwargs, with_shape=True))
            print("-after:", string_type(new_kwargs, with_shape=True))
        """
        return self._generic_walker(
            self.ChangeDimensionProcessor(desired_values, only_desired=only_desired),
            args_kwargs=args_kwargs,
        )


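# Illustrative sketch (not part of the original file): ``desired_values`` can
# pin a named dimension to a chosen value instead of the default ``dim + 1``,
# e.g., assuming ``kwargs`` and ``ds`` are defined as in the docstring example
# above, only the dimensions named "batch" are resized to 5, the others keep
# their value:
#
#   cpl = CoupleInputsDynamicShapes((), kwargs, ds)
#   new_kwargs = cpl.change_dynamic_dimensions(
#       desired_values={"batch": 5}, only_desired=True
#   )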
class ModelInputs:
    """
    Wraps a model and a couple of sets of valid inputs.
    Based on that information, the class is able to infer the dynamic shapes
    for :func:`torch.export.export`.

    :param model: model to export
    :param inputs: list of valid sets of inputs
    :param level: if this module is a submodule, its nesting level
    :param method_name: by default, the forward method is processed but it
        could be another one
    :param name: a name, mostly for debugging purposes

    Examples:

    **args**

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.export import ModelInputs


        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y


        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x, y)  # to check it works

        inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
        mi = ModelInputs(Model(), inputs)
        ds = mi.guess_dynamic_shapes()
        pprint.pprint(ds)

    **kwargs**

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.export import ModelInputs

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y


        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x=x, y=y)  # to check it works

        inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
        mi = ModelInputs(Model(), inputs)
        ds = mi.guess_dynamic_shapes()
        pprint.pprint(ds)

    **args and kwargs**

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.export import ModelInputs

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y


        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x, y=y)  # to check it works

        inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
        mi = ModelInputs(Model(), inputs)
        ds = mi.guess_dynamic_shapes()
        pprint.pprint(ds)

    :func:`torch.export.export` does not like dynamic shapes defined both as args and kwargs.
    kwargs must be used. ``move_to_kwargs`` modifies the inputs and the dynamic shapes
    to make the model and the given inputs exportable.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.export import ModelInputs
        from onnx_diagnostic.helpers import string_type


        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y


        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x, y=y)  # to check it works

        inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
        mi = ModelInputs(Model(), inputs)
        ds = mi.guess_dynamic_shapes()

        a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
        print("moved args:", string_type(a, with_shape=True))
        print("moved kwargs:", string_type(kw, with_shape=True))
        print("dynamic shapes:")
        pprint.pprint(nds)
    """

    def __init__(
        self,
        model: torch.nn.Module,
        inputs: Union[
            List[Tuple[Any, ...]],
            List[Dict[str, Any]],
            List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
        ],
        level: int = 0,
        method_name: str = "forward",
        name: str = "main",
    ):
        assert (
            model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
        ), (
            f"unexpected type for model={type(model)}, "
            f"it must be a torch.nn.Module, a python module, or None"
        )
        assert name, (
            f"name={name!r} cannot be empty, this string is used to "
            f"display meaningful error messages"
        )
        self.name = name
        self.model = model
        self.level = level
        self.method_name = method_name
        self.forward = getattr(model, method_name) if model is not None else None
        self.signature = inspect.signature(self.forward) if self.forward else None

        # information about the signature
        self.forward_parameter_names = (
            set(
                p.name
                for p in self.signature.parameters.values()
                if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
            )
            if self.signature
            else None
        )
        self.forward_ordered_parameter_names = (
            list(self.signature.parameters) if self.signature else None
        )
        self.forward_positioned_parameter_names = (
            [
                p.name
                for p in self.signature.parameters.values()
                if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            ]
            if self.signature
            else None
        )
        names = (
            [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
            if self.signature
            else None
        )
        self.forward_args = names[0] if names else None
        names = (
            [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
            if self.signature
            else None
        )
        self.forward_kwargs = names[0] if names else None
        self.forward_custom_op_schema = None
        self.forward_need_serialization = False
        self.forward_fill_kwargs = bool(self.forward_kwargs)
        assert not isinstance(
            model, (torch.nn.ModuleList, torch.nn.ModuleDict)
        ), f"ModuleList or ModuleDict should not be traced: {type(model)}"

        # process the inputs
        self.inputs = self.process_inputs(inputs)

    def process_inputs(
        self,
        inputs: Union[
            List[Tuple[Any, ...]],
            List[Dict[str, Any]],
            List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
        ],
    ) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
        """
        Transforms a list of valid inputs, a list of args, a list of kwargs,
        or a list of both into a list of (args, kwargs).
        """
        if not isinstance(inputs, list):
            raise ValueError(
                f"inputs should be specified as a list of sets of "
                f"inputs but type(inputs) is {type(inputs)}"
            )
        new_inputs = []
        for i, inp in enumerate(inputs):
            if (
                isinstance(inp, tuple)
                and len(inp) == 2
                and isinstance(inp[0], tuple)
                and isinstance(inp[1], dict)
            ):
                new_inputs.append(inp)
                continue
            if isinstance(inp, tuple):
                new_inputs.append((inp, {}))
                continue
            if isinstance(inp, dict):
                new_inputs.append(((), inp))
                continue
            raise ValueError(f"Unable to interpret inputs {i}: {string_type(inp)}")
        return new_inputs

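    # Illustrative sketch (not part of the original file): every accepted
    # input flavour is normalized into an (args, kwargs) pair, e.g.
    #
    #   mi.process_inputs([(x, y)])            -> [((x, y), {})]
    #   mi.process_inputs([dict(x=x, y=y)])    -> [((), {"x": x, "y": y})]
    #   mi.process_inputs([((x,), dict(y=y))]) -> [((x,), {"y": y})]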
    @property
    def true_model_name(self) -> str:
        "Returns the class name or the module name."
        assert self.model is not None, "model was None when the class was initialized."
        return (
            self.model.__class__.__name__
            if isinstance(self.model, torch.nn.Module)
            else self.model.__name__
        )

    @property
    def full_name(self) -> str:
        "Returns a name and class name."
        if self.method_name == "forward":
            return f"{self.name}:{self.true_model_name}"
        return f"{self.name}:{self.true_model_name}.{self.method_name}"

    @property
    def module_name_type(self):
        "Returns name and module type."
        if self.method_name == "forward":
            return f"type({self.name})={self.true_model_name}"
        return f"type({self.name})={self.true_model_name}.{self.method_name}"

    def guess_dynamic_dimensions(
        self, *tensors, auto: Union[bool, str] = False
    ) -> Optional[Dict[int, Any]]:
        """
        Infers the dynamic dimensions from multiple shapes.
        If ``auto`` is True, it returns ``torch.export.Dim.AUTO`` for every
        dimension which cannot be guessed. A dimension can be guessed when two
        tensors disagree on its value, but not when only one tensor is given.
        ``auto`` can be a string to produce strings instead.
        """
        if len(tensors) == 1:
            if isinstance(tensors[0], (int, float)):
                return None
            assert isinstance(tensors[0], torch.Tensor), (
                f"Unexpected type for tensors {string_type(tensors, with_shape=True)}, "
                f"Only tensors are allowed."
            )
            return (
                {i: torch.export.Dim.AUTO for i in range(len(tensors[0].shape))}  # noqa: C420
                if auto and not isinstance(auto, str)
                else {}
            )
        shapes = [t.shape for t in tensors]
        set_length = set(len(s) for s in shapes)
        assert len(set_length) == 1, (
            f"Shapes can be different but not ranks, possible ranks={set_length}, "
            f"shapes={shapes} for module {self.name!r}, "
            f"class={self.true_model_name!r}"
        )
        dynamic: Any = (
            auto
            if isinstance(auto, str)
            else (torch.export.Dim.AUTO if auto else torch.export.Dim.DYNAMIC)
        )
        rk = set_length.pop()
        res = {}
        for i in range(rk):
            set_dim = set(s[i] for s in shapes)
            if len(set_dim) > 1:
                res[i] = dynamic if not isinstance(dynamic, str) else f"{dynamic}{i}"
                continue
            if set_dim == {0}:
                # It is unexpected to find a null dimension. Let's replace it by a dynamic one.
                res[i] = dynamic if not isinstance(dynamic, str) else f"{dynamic}{i}"
                continue
        return res

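    # Illustrative sketch (not part of the original file): with two executions
    # the varying dimension becomes dynamic and the stable one stays static:
    #
    #   mi.guess_dynamic_dimensions(torch.rand(3, 7), torch.rand(3, 8))
    #   -> {1: torch.export.Dim.DYNAMIC}
    #   mi.guess_dynamic_dimensions(torch.rand(3, 7), torch.rand(3, 8), auto="d")
    #   -> {1: "d1"}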
    def guess_dynamic_shape_object(
        self, *objs: Any, auto: Union[bool, str] = False, msg: Optional[Callable] = None
    ) -> Any:
        """Guesses the dynamic shapes for one argument."""
        if len(objs) == 0:
            return None
        set_types = set(type(o) for o in objs)
        assert (
            len(set_types) == 1
        ), f"Unexpected variety of input type {set_types}{msg() if msg else ''}"
        obj = objs[0]
        if obj is None:
            return None
        if isinstance(obj, (bool, int, float, str)):
            return None
        if isinstance(obj, (torch.Tensor, np.ndarray)):
            return self.guess_dynamic_dimensions(*objs, auto=auto)

        if isinstance(obj, tuple):
            kl = set(len(o) for o in objs)
            assert (
                len(kl) == 1
            ), f"Unexpected variety of tuple lengths {kl}{msg() if msg else ''}"
            shapes: Any = []
            for i in range(kl.pop()):
                shapes.append(
                    self.guess_dynamic_shape_object(
                        *[o[i] for o in objs],
                        auto=auto if isinstance(auto, bool) else f"{auto}_{i}t",
                        msg=msg,
                    )
                )
            return tuple(shapes)

        if isinstance(obj, list):
            kl = set(len(o) for o in objs)
            assert (
                len(kl) == 1
            ), f"Unexpected variety of list lengths {kl}{msg() if msg else ''}"
            shapes = []
            for i in range(kl.pop()):
                shapes.append(
                    self.guess_dynamic_shape_object(
                        *[o[i] for o in objs],
                        auto=auto if isinstance(auto, bool) else f"{auto}_{i}l",
                        msg=msg,
                    )
                )
            return shapes

        if isinstance(obj, dict):
            kl = set(len(o) for o in objs)
            assert (
                len(kl) == 1
            ), f"Unexpected variety of dict lengths {kl}{msg() if msg else ''}"
            shapes = {}
            for i in obj:
                shapes[i] = self.guess_dynamic_shape_object(
                    *[o[i] for o in objs],
                    auto=auto if isinstance(auto, bool) else f"{auto}_{i}d",
                    msg=msg,
                )
            return shapes

        if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
            kcl = set(o.__class__ for o in objs)
            assert len(kcl) == 1, (
                f"All instances of the argument are not of the same class but {kcl}, "
                f"types should be the same."
            )
            col_args = [flatten_unflatten_for_dynamic_shapes(o) for o in objs]
            kc = set(len(o) for o in col_args)
            assert len(kc) == 1, (
                f"All instances of type {kcl.pop()} are not serialized into the same number "
                f"of arguments, it should be the same."
            )
            values = []
            for i in range(kc.pop()):
                values.append(
                    self.guess_dynamic_shape_object(
                        *[ca[i] for ca in col_args],
                        auto=auto if isinstance(auto, bool) else f"{auto}_{i}o",
                        msg=msg,
                    )
                )
            return values

        # In case DynamicCache is not registered.
        if obj.__class__.__name__ == "DynamicCache":
            if hasattr(obj, "layers"):
                kc = set(len(o.layers) for o in objs)
                assert (
                    len(kc) == 1
                ), f"All attributes 'layers' should have the same length but found {kc}"
                vc = kc.copy()
            else:
                kc = set(len(o.key_cache) for o in objs)
                assert (
                    len(kc) == 1
                ), f"All attributes 'key_cache' should have the same length but found {kc}"
                vc = set(len(o.value_cache) for o in objs)
                assert (
                    len(vc) == 1
                ), f"All attributes 'value_cache' should have the same length but found {vc}"

            key_cache = []
            for i in range(kc.pop()):
                key_cache.append(
                    self.guess_dynamic_dimensions(
                        *[
                            o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
                            for o in objs
                        ],
                        auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
                    )
                )
            value_cache = []
            for i in range(vc.pop()):
                value_cache.append(
                    self.guess_dynamic_dimensions(
                        *[
                            o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
                            for o in objs
                        ],
                        auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                    )
                )
            return list(itertools.chain.from_iterable(zip(key_cache, value_cache)))

        raise NotImplementedError(
            f"Unable to build dynamic shapes for type {set_types.pop()}: "
            f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}, "
            f"this object needs a serialization function to be registered."
        )

    def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES:
        """
        Guesses the dynamic shapes for that module from two executions.
        If there is only one execution, the dimensions are assumed to be static.

        :param auto: if ``auto`` is True, uses ``torch.export.Dim.AUTO`` for any
            dimension if the number of inputs is one;
            if ``auto`` is a string, it uses strings
        :return: guessed dynamic shapes

        See example :ref:`l-guess-dynamic-shapes-example`.
        """
        if len(self.inputs) == 0:
            # No inputs, unable to guess.
            return (tuple(), {})
        if len(self.inputs) == 1:
            # No dynamic shapes.
            return tuple(
                self.guess_dynamic_shape_object(a, auto=auto) for a in self.inputs[0][0]
            ), {
                k: self.guess_dynamic_shape_object(v, auto=auto)
                for k, v in self.inputs[0][1].items()
            }

        # Otherwise.
        s1 = set(len(i[0]) for i in self.inputs)
        assert (
            len(s1) == 1
        ), f"Different numbers of positional arguments {s1} for {self.full_name}"
        s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
        assert len(s2) == 1, f"Different named arguments {s2} for {self.full_name}"
        args = []
        kwargs = {}
        for i in range(s1.pop()):
            objs = [_[0][i] for _ in self.inputs]
            args.append(
                self.guess_dynamic_shape_object(
                    *objs,
                    auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
                    msg=lambda i=i: f" failing input {i}",
                )
            )
        names = s2.pop()
        for i, name in enumerate(names):
            assert name not in {"_diag", "verbose"}, (
                f"{self.full_name}: unexpected parameter {name!r}, names={names}"
                f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
                f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
            )

            objs = [_[1][name] for _ in self.inputs]
            kwargs[name] = self.guess_dynamic_shape_object(
                *objs,
                auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
                msg=lambda name=name: f" failing input {name!r}",
            )
        return tuple(args), kwargs

    def move_to_kwargs(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
    ) -> Tuple[Tuple[Any, ...], Dict[str, Any], DYNAMIC_SHAPES]:
        """
        Uses the signature to move positional arguments (args) to named arguments (kwargs)
        with the corresponding dynamic shapes.
        *kwargs* and *dynamic_shapes* are modified inplace.
        """
        assert (
            self.signature is not None
            and self.forward_parameter_names is not None
            and self.forward_ordered_parameter_names is not None
        ), (
            "model was None when the class was initialized, "
            "cannot move args to kwargs without the signature."
        )
        sig = self.signature
        arg_dyn, kw_dyn = dynamic_shapes
        for i, p in enumerate(sig.parameters):
            if i >= len(arg_dyn):
                break
            kw_dyn[p] = arg_dyn[i]
        if self.forward_kwargs:
            kdw = {}
            for k, v in kw_dyn.items():
                if k not in self.forward_parameter_names:
                    kdw[k] = v
            if kdw:
                for k in kdw:
                    del kw_dyn[k]
                kw_dyn[self.forward_kwargs] = kdw

            # Let's reorder as it seems to matter later
            # in the shape inference algorithm.
            _kwargs = kwargs
            kwargs = {}
            _kw_dyn = kw_dyn
            kw_dyn = {}
            for name in self.forward_ordered_parameter_names:
                if name in _kwargs:
                    kwargs[name] = _kwargs[name]
                if name in _kw_dyn:
                    kw_dyn[name] = _kw_dyn[name]
            for k in _kwargs:
                if k not in kwargs:
                    # Then it is part of **kwargs.
                    kwargs[k] = _kwargs[k]
            assert len(kw_dyn) == len(_kw_dyn), (
                f"{self.full_name}: unexpected mismatch between _kw_dyn={set(_kw_dyn)} "
                f"and kw_dyn={set(kw_dyn)}, "
                f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
            )
            assert len(kwargs) == len(_kwargs), (
                f"{self.full_name}: unexpected mismatch between _kwargs={set(_kwargs)} "
                f"and kwargs={set(kwargs)}, "
                f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
            )
        return args, kwargs, (tuple(), kw_dyn)

    def validate_inputs_for_export(
        self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
    ) -> List[List[Union[int, str]]]:
        """
        Validates the inputs the class contains for the given dynamic shapes.
        If not specified, the dynamic shapes are guessed.

        :param dynamic_shapes: dynamic shapes to validate
        :return: a list of lists, every list contains the path to the invalid dimension
        """
        if dynamic_shapes is None:
            if len(self.inputs) == 1:
                return []
            dyn_shapes = self.guess_dynamic_shapes()
        else:
            dyn_shapes = dynamic_shapes
        return [
            CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_dimensions_for_export()
            for i in self.inputs
        ]
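A minimal end-to-end sketch (illustrative, not part of the package) combining the
two classes above: guess the dynamic shapes from two executions, validate the
inputs, then export. The validation step with CoupleInputsDynamicShapes assumes
positional inputs only:

    import torch
    from onnx_diagnostic.export import ModelInputs
    from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    # Two executions with different shapes so the dynamic dimensions can be guessed.
    inputs = [
        (torch.randn(5, 6), torch.randn(1, 6)),
        (torch.randn(7, 8), torch.randn(1, 8)),
    ]
    mi = ModelInputs(Model(), inputs)
    args_ds, kwargs_ds = mi.guess_dynamic_shapes()

    # No dynamic dimension should be 0 or 1 in the inputs used for the export.
    cpl = CoupleInputsDynamicShapes(mi.inputs[0][0], {}, args_ds)
    assert cpl.invalid_dimensions_for_export() is None

    ep = torch.export.export(Model(), mi.inputs[0][0], dynamic_shapes=args_ds)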