onnx-diagnostic 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +411 -0
- onnx_diagnostic/doc.py +32 -0
- onnx_diagnostic/export/__init__.py +1 -1
- onnx_diagnostic/export/dynamic_shapes.py +433 -22
- onnx_diagnostic/ext_test_case.py +90 -29
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/{cache_helpers.py → helpers/cache_helper.py} +62 -4
- onnx_diagnostic/{helpers.py → helpers/helper.py} +136 -659
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/onnx_helper.py +921 -0
- onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +54 -4
- onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +142 -55
- onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
- onnx_diagnostic/reference/ort_evaluator.py +7 -2
- onnx_diagnostic/torch_export_patches/__init__.py +107 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +160 -28
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +13 -2
- onnx_diagnostic/torch_export_patches/patch_inputs.py +174 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +194 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +18 -5
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +195 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3259 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +727 -0
- onnx_diagnostic/torch_models/llms.py +2 -96
- onnx_diagnostic/torch_models/test_helper.py +827 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +108 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +103 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/sbs.py +439 -0
- {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/METADATA +4 -2
- onnx_diagnostic-0.3.0.dist-info/RECORD +73 -0
- {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/WHEEL +1 -1
- onnx_diagnostic/onnx_tools.py +0 -260
- onnx_diagnostic-0.2.1.dist-info/RECORD +0 -55
- /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
- {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/top_level.txt +0 -0
|
@@ -3,6 +3,344 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
5
|
from ..helpers import string_type
|
|
6
|
+
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
7
|
+
|
|
8
|
+
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]


class CoupleInputsDynamicShapes:
    """
    Pair inputs / dynamic shapes.

    The class walks through the inputs and the dynamic shapes side by side
    and applies a processor on every couple *(tensor, dynamic shape)*.

    :param args: positional arguments
    :param kwargs: named arguments
    :param dynamic_shapes: dynamic shapes
    :param args_names: if both args and kwargs are not empty, then
        dynamic shapes must be a dictionary, and positional must be added
        to the named arguments. Arguments names or a module must be given
        in that case.
    """

    def __init__(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        dynamic_shapes: DYNAMIC_SHAPES,
        args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
    ):
        self.args = args
        self.kwargs = kwargs
        self.dynamic_shapes = dynamic_shapes
        self.args_names = args_names

    def __str__(self) -> str:
        """Returns a multi-line description, one attribute per line."""
        # Every element ends with a comma inside the list: without it, Python
        # concatenates adjacent string literals and all fields end up on one line.
        return "\n".join(
            [
                f"{self.__class__.__name__}(",
                f"    args={string_type(self.args, with_shape=True)},",
                f"    kwargs={string_type(self.kwargs, with_shape=True)},",
                f"    dynamic_shapes={string_type(self.dynamic_shapes, with_shape=True)},",
                f")",
            ]
        )

    def replace_string_by(self, value: Any = None):
        """
        Replaces string by the value ``torch.export.Dim.DYNAMIC``
        (default) or any other value specified by value.

        :param value: replacement for every dimension defined by a string,
            ``torch.export.Dim.DYNAMIC`` if None
        :return: a structure similar to the dynamic shapes

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())
        """
        return self._generic_walker(
            lambda inputs, ds, value=value: self._replace_string_dim_tensor(
                inputs, ds, value=value
            )
        )

    @classmethod
    def _replace_string_dim_tensor(cls, inputs, ds, value=None):
        """
        Returns a copy of *ds* (the dynamic shape of one tensor) where every
        dimension defined by a string is replaced by *value*.
        The original dictionary is left untouched.
        """
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension dimension"
        )
        if value is None:
            value = torch.export.Dim.DYNAMIC
        new_ds = ds.copy()
        for i, v in ds.items():
            if isinstance(v, str):
                new_ds[i] = value
        return new_ds

    def invalid_dimensions_for_export(self):
        """
        Tells if the inputs are valid based on the dynamic shapes definition.
        The method assumes that all custom classes can be serialized.
        If some patches were applied to export, they should be enabled while
        calling this method if the inputs contains such classes.

        The function checks that a dynamic dimension does not receive a value
        of 0 or 1. It returns the unexpected values in the same structure as
        the given dynamic shapes.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())

        In case it works, it shows:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x2 = torch.rand((3, 2))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x2, T3x2)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
        """
        return self._generic_walker(self._valid_shapes_tensor)

    @classmethod
    def _valid_shapes_tensor(cls, inputs, ds):
        """
        Returns a dictionary ``{axis: message}`` for every dynamic axis of
        *inputs* whose size is 0 or 1, None if there is none.
        An axis is considered dynamic when *ds* gives it a value
        which is not an integer.
        """
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension dimension"
        )
        issues = {}
        for i, d in enumerate(inputs.shape):
            if i in ds and not isinstance(ds[i], int):
                # dynamic then
                if d in {0, 1}:
                    # export issues for sure
                    issues[i] = f"d=[{d}]"
        return issues if issues else None

    def _generic_walker(self, processor: Callable):
        """
        Generic walker going through inputs and dynamic_shapes all along.
        The function returns a result with the same structure as the dynamic shapes.

        :param processor: function called with ``(tensor, dynamic_shape)``
            on every tensor found in the inputs
        :return: the results of the processor, None if every call returned None
        """
        if not self.args:
            # Only named arguments.
            assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
                f"Type mismatch, args={string_type(self.args)} and "
                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
            )
            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)

        if not self.kwargs:
            # Only positional arguments.
            assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
                f"Type mismatch, args={string_type(self.args)} and "
                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
            )
            return self._generic_walker_step(processor, self.args, self.dynamic_shapes)

        assert isinstance(self.dynamic_shapes, dict), (
            f"Both positional and named arguments (args and kwargs) are filled. "
            f"dynamic shapes must a dictionary not {type(self.dynamic_shapes)}"
        )
        if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
            self.dynamic_shapes
        ):
            # No dynamic shapes for the positional arguments.
            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)

        if isinstance(self.args_names, list):
            if not set(self.args_names) & set(self.dynamic_shapes):
                # No dynamic shapes for the positional arguments.
                return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)

            assert self.args_names, (
                "args and kwargs are filled, then args_names must be specified in "
                "the constructor to move positional arguments to named arguments."
            )
            assert len(self.args) <= len(self.args_names), (
                f"There are {len(self.args)} positional arguments "
                f"but only {len(self.args_names)} names. "
                f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
            )
            # Positional arguments become named arguments.
            kwargs = dict(zip(self.args_names, self.args))
            kwargs.update(self.kwargs)
            return self._generic_walker_step(processor, kwargs, self.dynamic_shapes)

        raise NotImplementedError(
            f"Not yet implemented when args is filled, "
            f"kwargs as well but args_names is {type(self.args_names)}"
        )

    @classmethod
    def _generic_walker_step(cls, processor: Callable, inputs, ds):
        """
        Recursive step of :meth:`_generic_walker`.

        * a tensor is given to the processor,
        * an integer, a float or a string returns None,
        * a tuple, a list or a dictionary is walked through recursively,
          *ds* must have the same type and the same length,
        * any other class must be registered in
          ``torch.utils._pytree.SUPPORTED_NODES``, it is flattened and
          processed as a list.

        A container for which every item returned None becomes None.
        """
        if isinstance(inputs, torch.Tensor):
            return processor(inputs, ds)
        if isinstance(inputs, (int, float, str)):
            return None
        if isinstance(inputs, (tuple, list, dict)):
            assert type(ds) is type(
                inputs
            ), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}"
            assert len(ds) == len(inputs), (
                f"Length mismatch between inputs {len(inputs)} "
                f"and ds={len(ds)}\n"
                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
            )
            if isinstance(inputs, (tuple, list)):
                value = []
                for i, d in zip(inputs, ds):
                    value.append(cls._generic_walker_step(processor, i, d))
                return (
                    (value if isinstance(ds, list) else tuple(value))
                    if any(v is not None for v in value)
                    else None
                )
            assert set(inputs) == set(
                ds
            ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
            dvalue = {}
            for k, v in inputs.items():
                t = cls._generic_walker_step(processor, v, ds[k])
                if t is not None:
                    dvalue[k] = t
            return dvalue if dvalue else None

        # A custom class.
        assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
            f"Class {inputs.__class__.__name__!r} was not registered using "
            f"torch.utils._pytree.register_pytree_node, it is not possible to "
            f"map this class with the given dynamic shapes."
        )
        flat, _spec = torch.utils._pytree.tree_flatten(inputs)
        return cls._generic_walker_step(processor, flat, ds)

    class ChangeDimensionProcessor:
        """
        Processor producing a tensor with different sizes on its dynamic axes.
        Attribute *mapping* remembers the new size chosen for every
        dimension name so that the same name gets the same size
        in every tensor.
        """

        def __init__(self):
            self.mapping = {}

        def _build_new_shape(
            self, shape: Tuple[int, ...], ds: Dict[int, Any]
        ) -> Tuple[int, ...]:
            """
            Returns the new shape: a dynamic axis seen for the first time
            gets its current size + 1, a known one gets the registered size.
            An axis mapped to an integer is static and is left unchanged.

            :raises NotImplementedError: if the dimension is not
                a string, an integer or a ``Dim`` instance
            """
            new_shape = list(shape)
            for i, dim in enumerate(shape):
                if i not in ds:
                    continue
                spec = ds[i]
                if isinstance(spec, int):
                    # Static dimension (same convention as _valid_shapes_tensor),
                    # there is no dimension name to look up.
                    continue
                if isinstance(spec, str):
                    d = spec
                elif isinstance(
                    spec,
                    (
                        torch.export.dynamic_shapes._DerivedDim,
                        torch.export.dynamic_shapes._Dim,
                    ),
                ):
                    d = str(spec)
                else:
                    raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
                if d in self.mapping:
                    new_dim = self.mapping[d]
                else:
                    new_dim = dim + 1
                    self.mapping[d] = new_dim
                new_shape[i] = new_dim
            return tuple(new_shape)

        def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
            """
            Resizes *tensor* to *new_shape*, one axis after the other.
            Existing values are kept, additional slices repeat the existing
            ones cyclically. If the axis was empty, additional slices
            are left to zero.
            """
            rank = len(tensor.shape)
            for i in range(rank):
                d0 = tensor.shape[i]
                d1 = new_shape[i]
                if d0 == d1:
                    continue
                alt_shape = list(tensor.shape)
                alt_shape[i] = d1
                new_tensor = torch.zeros(
                    tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
                )
                mind = min(d0, d1)
                indices = [slice(None) for _ in range(rank)]
                indices[i] = slice(0, mind)
                ind = tuple(indices)
                new_tensor[ind] = tensor[ind]
                # mind == 0 means there is nothing to replicate
                # (and k % mind would fail).
                if d1 > mind and mind > 0:
                    for k in range(d1 - mind):
                        indices0 = [slice(None) for _ in range(rank)]
                        indices1 = [slice(None) for _ in range(rank)]
                        indices1[i] = mind + k
                        indices0[i] = k % mind
                        new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
                tensor = new_tensor
            return tensor

        def __call__(self, inputs, ds):
            """Returns a tensor similar to *inputs* with modified dynamic dimensions."""
            assert isinstance(
                inputs, torch.Tensor
            ), f"unexpected type for inputs {type(inputs)}"
            assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
                f"Unexpected types, inputs is a Tensor but ds is {ds}, "
                f"a dictionary is expected to specify a dimension dimension"
            )
            new_shape = self._build_new_shape(inputs.shape, ds)
            return self._build_new_tensor(inputs, new_shape)

    def change_dynamic_dimensions(self):
        """
        A model exported with dynamic shapes is not necessarily dynamic
        just because the user specified dynamic shapes. The algorithm
        may discover that a dimension cannot be dynamic and then continues
        the export making the assumption it is static. That may lead to a wrong
        model. This function produces a new set of inputs with different values
        for the dimension than the first ones, assuming they were used to export
        the model.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x15 = torch.rand((3, 15))
            T3x20 = torch.rand((3, 20))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x15, T3x20)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
            print("before:", string_type(kwargs, with_shape=True))
            print("-after:", string_type(new_kwargs, with_shape=True))
        """
        return self._generic_walker(self.ChangeDimensionProcessor())
|
|
6
344
|
|
|
7
345
|
|
|
8
346
|
class ModelInputs:
|
|
@@ -69,7 +407,7 @@ class ModelInputs:
|
|
|
69
407
|
ds = mi.guess_dynamic_shapes()
|
|
70
408
|
pprint.pprint(ds)
|
|
71
409
|
|
|
72
|
-
**
|
|
410
|
+
**args and kwargs**
|
|
73
411
|
|
|
74
412
|
.. runpython::
|
|
75
413
|
:showcode:
|
|
@@ -218,7 +556,7 @@ class ModelInputs:
|
|
|
218
556
|
return new_inputs
|
|
219
557
|
|
|
220
558
|
@property
|
|
221
|
-
def true_model_name(self):
|
|
559
|
+
def true_model_name(self) -> str:
|
|
222
560
|
"Returns class name or module name."
|
|
223
561
|
return (
|
|
224
562
|
self.model.__class__.__name__
|
|
@@ -227,7 +565,7 @@ class ModelInputs:
|
|
|
227
565
|
)
|
|
228
566
|
|
|
229
567
|
@property
|
|
230
|
-
def full_name(self):
|
|
568
|
+
def full_name(self) -> str:
|
|
231
569
|
"Returns a name and class name."
|
|
232
570
|
if self.method_name == "forward":
|
|
233
571
|
return f"{self.name}:{self.true_model_name}"
|
|
@@ -240,10 +578,27 @@ class ModelInputs:
|
|
|
240
578
|
return f"type({self.name})={self.true_model_name}"
|
|
241
579
|
return f"type({self.name})={self.true_model_name}.{self.method_name}"
|
|
242
580
|
|
|
243
|
-
def guess_dynamic_dimensions(
|
|
244
|
-
|
|
581
|
+
def guess_dynamic_dimensions(
|
|
582
|
+
self, *tensors, auto: bool = False
|
|
583
|
+
) -> Optional[Dict[int, Any]]:
|
|
584
|
+
"""
|
|
585
|
+
Infers the dynamic dimension from multiple shapes.
|
|
586
|
+
If auto is True, it returns ``torch.export.Dim.AUTO`` for every dimension
|
|
587
|
+
which cannot be guessed. Two tensors with the same value for one dimension
|
|
588
|
+
can be guessed, but if there is only 1, it cannot.
|
|
589
|
+
"""
|
|
245
590
|
if len(tensors) == 1:
|
|
246
|
-
|
|
591
|
+
if isinstance(tensors[0], (int, float)):
|
|
592
|
+
return None
|
|
593
|
+
assert isinstance(tensors[0], torch.Tensor), (
|
|
594
|
+
f"Unexpected type for tensors {string_type(tensors, with_shape=True)}, "
|
|
595
|
+
f"Only tensors are allowed."
|
|
596
|
+
)
|
|
597
|
+
return (
|
|
598
|
+
{i: torch.export.Dim.AUTO for i in range(len(tensors[0].shape))} # noqa: C420
|
|
599
|
+
if auto
|
|
600
|
+
else {}
|
|
601
|
+
)
|
|
247
602
|
shapes = [t.shape for t in tensors]
|
|
248
603
|
set_length = set(len(s) for s in shapes)
|
|
249
604
|
assert len(set_length) == 1, (
|
|
@@ -265,7 +620,9 @@ class ModelInputs:
|
|
|
265
620
|
continue
|
|
266
621
|
return res
|
|
267
622
|
|
|
268
|
-
def guess_dynamic_shape_object(
|
|
623
|
+
def guess_dynamic_shape_object(
|
|
624
|
+
self, *objs: Any, auto: bool = False, msg: Optional[Callable] = None
|
|
625
|
+
) -> Any:
|
|
269
626
|
"""Guesses the dynamic shapes for one argument."""
|
|
270
627
|
if len(objs) == 0:
|
|
271
628
|
return None
|
|
@@ -279,7 +636,7 @@ class ModelInputs:
|
|
|
279
636
|
if isinstance(obj, (bool, int, float, str)):
|
|
280
637
|
return None
|
|
281
638
|
if isinstance(obj, (torch.Tensor, np.ndarray)):
|
|
282
|
-
return self.guess_dynamic_dimensions(*objs)
|
|
639
|
+
return self.guess_dynamic_dimensions(*objs, auto=auto)
|
|
283
640
|
|
|
284
641
|
if isinstance(obj, tuple):
|
|
285
642
|
kl = set(len(o) for o in objs)
|
|
@@ -288,7 +645,9 @@ class ModelInputs:
|
|
|
288
645
|
), f"Unexpected variety of tuple lengths {kl}{msg() if msg else ''}"
|
|
289
646
|
shapes: Any = []
|
|
290
647
|
for i in range(kl.pop()):
|
|
291
|
-
shapes.append(
|
|
648
|
+
shapes.append(
|
|
649
|
+
self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
|
|
650
|
+
)
|
|
292
651
|
return tuple(shapes)
|
|
293
652
|
|
|
294
653
|
if isinstance(obj, list):
|
|
@@ -298,7 +657,9 @@ class ModelInputs:
|
|
|
298
657
|
), f"Unexpected variety of list lengths {kl}{msg() if msg else ''}"
|
|
299
658
|
shapes = []
|
|
300
659
|
for i in range(kl.pop()):
|
|
301
|
-
shapes.append(
|
|
660
|
+
shapes.append(
|
|
661
|
+
self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
|
|
662
|
+
)
|
|
302
663
|
return shapes
|
|
303
664
|
|
|
304
665
|
if isinstance(obj, dict):
|
|
@@ -308,9 +669,33 @@ class ModelInputs:
|
|
|
308
669
|
), f"Unexpected variety of dict lengths {kl}{msg() if msg else ''}"
|
|
309
670
|
shapes = {}
|
|
310
671
|
for i in obj:
|
|
311
|
-
shapes[i] = self.guess_dynamic_shape_object(
|
|
672
|
+
shapes[i] = self.guess_dynamic_shape_object(
|
|
673
|
+
*[o[i] for o in objs], auto=auto, msg=msg
|
|
674
|
+
)
|
|
312
675
|
return shapes
|
|
313
676
|
|
|
677
|
+
if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
678
|
+
kcl = set(o.__class__ for o in objs)
|
|
679
|
+
assert len(kcl) == 1, (
|
|
680
|
+
f"All instances of argument {i} are not of the same class but {kcl}, "
|
|
681
|
+
f"types should be the same."
|
|
682
|
+
)
|
|
683
|
+
col_args = [flatten_unflatten_for_dynamic_shapes(o) for o in objs]
|
|
684
|
+
kc = set(len(o) for o in col_args)
|
|
685
|
+
assert len(kc) == 1, (
|
|
686
|
+
f"All instances of type {kcl.pop()} are not serialized into the same number "
|
|
687
|
+
f"of arguments, it should be the same."
|
|
688
|
+
)
|
|
689
|
+
values = []
|
|
690
|
+
for i in range(kc.pop()):
|
|
691
|
+
values.append(
|
|
692
|
+
self.guess_dynamic_shape_object(
|
|
693
|
+
*[ca[i] for ca in col_args], auto=auto, msg=msg
|
|
694
|
+
)
|
|
695
|
+
)
|
|
696
|
+
return values
|
|
697
|
+
|
|
698
|
+
# In case DynamicCache is not registered.
|
|
314
699
|
if obj.__class__.__name__ == "DynamicCache":
|
|
315
700
|
kc = set(len(o.key_cache) for o in objs)
|
|
316
701
|
assert (
|
|
@@ -323,34 +708,39 @@ class ModelInputs:
|
|
|
323
708
|
key_cache = []
|
|
324
709
|
for i in range(kc.pop()):
|
|
325
710
|
key_cache.append(
|
|
326
|
-
self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs])
|
|
711
|
+
self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs], auto=auto)
|
|
327
712
|
)
|
|
328
713
|
value_cache = []
|
|
329
714
|
for i in range(vc.pop()):
|
|
330
715
|
value_cache.append(
|
|
331
|
-
self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs])
|
|
716
|
+
self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs], auto=auto)
|
|
332
717
|
)
|
|
333
718
|
return [key_cache, value_cache]
|
|
334
719
|
|
|
335
720
|
raise NotImplementedError(
|
|
336
721
|
f"Unable to build dynamic shapes for type {set_types.pop()}: "
|
|
337
|
-
f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}"
|
|
722
|
+
f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}, "
|
|
723
|
+
f"this object needs serialization function to be registered."
|
|
338
724
|
)
|
|
339
725
|
|
|
340
|
-
def guess_dynamic_shapes(
|
|
341
|
-
self,
|
|
342
|
-
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
|
|
726
|
+
def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
|
|
343
727
|
"""
|
|
344
728
|
Guesses the dynamic shapes for that module from two execution.
|
|
345
729
|
If there is only one execution, then that would be static dimensions.
|
|
730
|
+
|
|
731
|
+
:param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
|
|
732
|
+
dimension if the number of inputs is one
|
|
346
733
|
"""
|
|
347
734
|
if len(self.inputs) == 0:
|
|
348
735
|
# No inputs, unable to guess.
|
|
349
736
|
return (tuple(), {})
|
|
350
737
|
if len(self.inputs) == 1:
|
|
351
738
|
# No dynamic shapes.
|
|
352
|
-
return tuple(
|
|
353
|
-
|
|
739
|
+
return tuple(
|
|
740
|
+
self.guess_dynamic_shape_object(a, auto=auto) for a in self.inputs[0][0]
|
|
741
|
+
), {
|
|
742
|
+
k: self.guess_dynamic_shape_object(v, auto=auto)
|
|
743
|
+
for k, v in self.inputs[0][1].items()
|
|
354
744
|
}
|
|
355
745
|
|
|
356
746
|
# Otherwise.
|
|
@@ -365,7 +755,9 @@ class ModelInputs:
|
|
|
365
755
|
for i in range(s1.pop()):
|
|
366
756
|
objs = [_[0][i] for _ in self.inputs]
|
|
367
757
|
args.append(
|
|
368
|
-
self.guess_dynamic_shape_object(
|
|
758
|
+
self.guess_dynamic_shape_object(
|
|
759
|
+
*objs, auto=auto, msg=lambda i=i: f" failing input {i}"
|
|
760
|
+
)
|
|
369
761
|
)
|
|
370
762
|
names = s2.pop()
|
|
371
763
|
for name in names:
|
|
@@ -377,7 +769,7 @@ class ModelInputs:
|
|
|
377
769
|
|
|
378
770
|
objs = [_[1][name] for _ in self.inputs]
|
|
379
771
|
kwargs[name] = self.guess_dynamic_shape_object(
|
|
380
|
-
*objs, msg=lambda name=name: f" failing input {name!r}"
|
|
772
|
+
*objs, auto=auto, msg=lambda name=name: f" failing input {name!r}"
|
|
381
773
|
)
|
|
382
774
|
return tuple(args), kwargs
|
|
383
775
|
|
|
@@ -386,7 +778,7 @@ class ModelInputs:
|
|
|
386
778
|
args: Tuple[Any, ...],
|
|
387
779
|
kwargs: Dict[str, Any],
|
|
388
780
|
dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
|
|
389
|
-
) -> Tuple[Tuple[Any, ...], Dict[str, Any],
|
|
781
|
+
) -> Tuple[Tuple[Any, ...], Dict[str, Any], DYNAMIC_SHAPES]:
|
|
390
782
|
"""
|
|
391
783
|
Uses the signatures to move positional arguments (args) to named arguments (kwargs)
|
|
392
784
|
with the corresponding dynamic shapes.
|
|
@@ -434,3 +826,22 @@ class ModelInputs:
|
|
|
434
826
|
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
|
|
435
827
|
)
|
|
436
828
|
return args, kwargs, (tuple(), kw_dyn)
|
|
829
|
+
|
|
830
|
+
def validate_inputs_for_export(
    self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
) -> List[List[Union[int, str]]]:
    """
    Validates the inputs the class contains for the given dynamic shapes.
    If not specified, the dynamic_shapes are guessed.

    :param dynamic_shapes: dynamic shapes to validate
    :return: a list with one item per set of inputs, the result of
        ``CoupleInputsDynamicShapes.invalid_dimensions_for_export`` for
        that set; the list is empty if the dynamic shapes must be guessed
        and there is only one set of inputs
    """
    if dynamic_shapes is None:
        if len(self.inputs) == 1:
            return []
        dyn_shapes = self.guess_dynamic_shapes()
    else:
        # Use the shapes given by the caller; they were previously ignored
        # and ``dyn_shapes`` was left undefined in that case.
        dyn_shapes = dynamic_shapes
    # NOTE(review): dyn_shapes is passed as-is to CoupleInputsDynamicShapes,
    # which expects it to match args (tuple) or kwargs (dict) -- confirm the
    # pair (args_shapes, kwargs_shapes) is the intended value here.
    return [
        CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_dimensions_for_export()
        for i in self.inputs
    ]
|