onnx-diagnostic 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +45 -4
- onnx_diagnostic/export/__init__.py +1 -0
- onnx_diagnostic/export/dynamic_shapes.py +169 -29
- onnx_diagnostic/export/validate.py +170 -0
- onnx_diagnostic/ext_test_case.py +66 -3
- onnx_diagnostic/helpers/cache_helper.py +81 -5
- onnx_diagnostic/helpers/config_helper.py +80 -0
- onnx_diagnostic/helpers/helper.py +283 -81
- onnx_diagnostic/helpers/ort_session.py +1 -39
- onnx_diagnostic/helpers/rt_helper.py +47 -0
- onnx_diagnostic/helpers/torch_test_helper.py +14 -3
- onnx_diagnostic/tasks/__init__.py +48 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +165 -0
- onnx_diagnostic/tasks/fill_mask.py +67 -0
- onnx_diagnostic/tasks/image_classification.py +96 -0
- onnx_diagnostic/tasks/image_text_to_text.py +145 -0
- onnx_diagnostic/tasks/sentence_similarity.py +67 -0
- onnx_diagnostic/tasks/text2text_generation.py +172 -0
- onnx_diagnostic/tasks/text_classification.py +67 -0
- onnx_diagnostic/tasks/text_generation.py +248 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +106 -0
- onnx_diagnostic/torch_export_patches/__init__.py +0 -107
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +21 -160
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +339 -61
- onnx_diagnostic/torch_export_patches/patch_inputs.py +29 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +29 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +37 -47
- onnx_diagnostic/torch_models/hghub/hub_api.py +26 -6
- onnx_diagnostic/torch_models/hghub/hub_data.py +22 -14
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +337 -25
- onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -608
- onnx_diagnostic/torch_models/test_helper.py +651 -228
- {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/METADATA +13 -3
- {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/RECORD +38 -25
- {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import argparse
|
|
1
2
|
import json
|
|
2
3
|
import sys
|
|
3
4
|
import textwrap
|
|
@@ -227,6 +228,21 @@ def _cmd_config(argv: List[Any]):
|
|
|
227
228
|
print(f"task: {task_from_id(args.mid)}")
|
|
228
229
|
|
|
229
230
|
|
|
231
|
+
class _ParseDict(argparse.Action):
    """
    argparse action storing ``KEY=VALUE`` arguments into a dictionary.

    Every item is split on its first ``=``. Blanks around the key are removed,
    the value is kept as it is (a string). Items given through several
    occurrences of the option are accumulated in the same dictionary.
    An item without ``=`` or with an empty key is rejected through
    ``parser.error`` (usage message and exit code 2) instead of a traceback.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Work on a copy: when the option has a dictionary as default, argparse
        # stores that very object in the namespace and updating it inplace
        # would silently modify the default of the parser.
        d = dict(getattr(namespace, self.dest, None) or {})
        for item in values or ():
            key, sep, value = item.partition("=")
            key = key.strip()  # we remove blanks around keys, as is logical
            if not sep or not key:
                parser.error(
                    f"argument {option_string or self.dest}: "
                    f"expected KEY=VALUE, got {item!r}"
                )
            d[key] = value
        setattr(namespace, self.dest, d)
|
|
244
|
+
|
|
245
|
+
|
|
230
246
|
def get_parser_validate() -> ArgumentParser:
|
|
231
247
|
parser = ArgumentParser(
|
|
232
248
|
prog="test",
|
|
@@ -287,22 +303,37 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
287
303
|
help="drops the following inputs names, it should be a list "
|
|
288
304
|
"with comma separated values",
|
|
289
305
|
)
|
|
306
|
+
parser.add_argument(
|
|
307
|
+
"--ortfusiontype",
|
|
308
|
+
required=False,
|
|
309
|
+
help="applies onnxruntime fusion, this parameter should contain the "
|
|
310
|
+
"model type or multiple values separated by `|`. `ALL` can be used "
|
|
311
|
+
"to run them all",
|
|
312
|
+
)
|
|
290
313
|
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
|
|
291
314
|
parser.add_argument("--dtype", help="changes dtype if necessary")
|
|
292
315
|
parser.add_argument("--device", help="changes the device if necessary")
|
|
316
|
+
parser.add_argument(
|
|
317
|
+
"--iop",
|
|
318
|
+
metavar="KEY=VALUE",
|
|
319
|
+
nargs="*",
|
|
320
|
+
help="Additional input options, use to change the default "
|
|
321
|
+
"inputs use to export, example: --iop cls_cache=SlidingWindowCache",
|
|
322
|
+
action=_ParseDict,
|
|
323
|
+
)
|
|
293
324
|
return parser
|
|
294
325
|
|
|
295
326
|
|
|
296
327
|
def _cmd_validate(argv: List[Any]):
|
|
297
328
|
from .helpers import string_type
|
|
298
|
-
from .torch_models.test_helper import get_inputs_for_task, validate_model
|
|
299
|
-
from .
|
|
329
|
+
from .torch_models.test_helper import get_inputs_for_task, validate_model
|
|
330
|
+
from .tasks import supported_tasks
|
|
300
331
|
|
|
301
332
|
parser = get_parser_validate()
|
|
302
333
|
args = parser.parse_args(argv[1:])
|
|
303
334
|
if not args.task and not args.mid:
|
|
304
335
|
print("-- list of supported tasks:")
|
|
305
|
-
print("\n".join(
|
|
336
|
+
print("\n".join(supported_tasks()))
|
|
306
337
|
elif not args.mid:
|
|
307
338
|
data = get_inputs_for_task(args.task)
|
|
308
339
|
if args.verbose:
|
|
@@ -313,8 +344,16 @@ def _cmd_validate(argv: List[Any]):
|
|
|
313
344
|
print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
|
|
314
345
|
print("-- dynamic_shapes")
|
|
315
346
|
for k, v in data["dynamic_shapes"].items():
|
|
316
|
-
print(f" + {k.ljust(max_length)}: {
|
|
347
|
+
print(f" + {k.ljust(max_length)}: {string_type(v)}")
|
|
317
348
|
else:
|
|
349
|
+
# Let's skip any invalid combination if known to be unsupported
|
|
350
|
+
if (
|
|
351
|
+
"onnx" not in (args.export or "")
|
|
352
|
+
and "custom" not in (args.export or "")
|
|
353
|
+
and (args.opt or "")
|
|
354
|
+
):
|
|
355
|
+
print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
|
|
356
|
+
return
|
|
318
357
|
summary, _data = validate_model(
|
|
319
358
|
model_id=args.mid,
|
|
320
359
|
task=args.task,
|
|
@@ -330,6 +369,8 @@ def _cmd_validate(argv: List[Any]):
|
|
|
330
369
|
exporter=args.export,
|
|
331
370
|
dump_folder=args.dump_folder,
|
|
332
371
|
drop_inputs=None if not args.drop else args.drop.split(","),
|
|
372
|
+
ortfusiontype=args.ortfusiontype,
|
|
373
|
+
input_options=args.iop,
|
|
333
374
|
)
|
|
334
375
|
print("")
|
|
335
376
|
print("-- summary --")
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import inspect
|
|
2
|
-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
2
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
5
|
from ..helpers import string_type
|
|
@@ -8,6 +8,30 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
|
8
8
|
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
def flatten_dynamic_shapes(ds: Any) -> Any:
    """
    Flattens nested dynamic shapes into a sequence of per-tensor dynamic shapes.

    A dictionary whose keys are all integers describes a single tensor
    (axis -> dimension) and is returned unchanged. Lists, tuples and
    dictionaries with other keys are containers: their elements are
    flattened recursively and concatenated.

    :param ds: nested dynamic shapes
    :return: *ds* itself if it describes a single tensor, otherwise a list
        of per-tensor dynamic shapes (a tuple when *ds* is a tuple)
    :raises AssertionError: for any other type, ``None`` included
    """
    if isinstance(ds, dict):
        if not all(isinstance(axis, int) for axis in ds):
            # named container (kwargs like): only the values matter
            return _flat_list(list(map(flatten_dynamic_shapes, ds.values())))
        # That's a dynamic shape
        return ds
    if isinstance(ds, (list, tuple)):
        flattened = _flat_list(list(map(flatten_dynamic_shapes, ds)))
        return tuple(flattened) if isinstance(ds, tuple) else flattened
    raise AssertionError(f"Not implemented for {type(ds)}: {ds}")


def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
    """
    Concatenates results of :func:`flatten_dynamic_shapes`: a dictionary counts
    as one item, any other element (list or tuple) is expanded in place.
    """
    return [
        leaf
        for element in li
        for leaf in ((element,) if isinstance(element, dict) else element)
    ]
|
|
33
|
+
|
|
34
|
+
|
|
11
35
|
class CoupleInputsDynamicShapes:
|
|
12
36
|
"""
|
|
13
37
|
Pair inputs / dynamic shapes.
|
|
@@ -68,7 +92,8 @@ class CoupleInputsDynamicShapes:
|
|
|
68
92
|
return self._generic_walker(
|
|
69
93
|
lambda inputs, ds, value=value: self._replace_string_dim_tensor(
|
|
70
94
|
inputs, ds, value=value
|
|
71
|
-
)
|
|
95
|
+
),
|
|
96
|
+
flatten_unflatten=True,
|
|
72
97
|
)
|
|
73
98
|
|
|
74
99
|
@classmethod
|
|
@@ -76,7 +101,7 @@ class CoupleInputsDynamicShapes:
|
|
|
76
101
|
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
|
|
77
102
|
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
|
|
78
103
|
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
|
|
79
|
-
f"a dictionary is expected to specify a dimension
|
|
104
|
+
f"a dictionary is expected to specify a dimension"
|
|
80
105
|
)
|
|
81
106
|
if value is None:
|
|
82
107
|
value = torch.export.Dim.DYNAMIC
|
|
@@ -86,6 +111,57 @@ class CoupleInputsDynamicShapes:
|
|
|
86
111
|
new_ds[i] = value
|
|
87
112
|
return new_ds
|
|
88
113
|
|
|
114
|
+
def replace_by_string(self):
    """
    Returns the dynamic shapes with every dimension expressed as a string.

    Strings are kept, ``torch.export.Dim`` objects are replaced by their name,
    ``Dim.DYNAMIC`` and ``Dim.AUTO`` receive a generated name. A single set of
    names is shared by all the tensors so that generated names do not repeat.

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

        Dim = torch.export.Dim
        T3x1 = torch.rand((3, 1))
        T3x4 = torch.rand((3, 4))
        ds_batch = {0: Dim("batch")}
        ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
        kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
        ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
        print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
    """
    names_in_use = set()

    def _tensor_to_strings(inputs, ds, unique=names_in_use):
        return self._replace_dim_tensor_by_string(inputs, ds, unique=unique)

    return self._generic_walker(_tensor_to_strings, flatten_unflatten=True)
|
|
142
|
+
|
|
143
|
+
@classmethod
def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
    """
    Replaces every dimension of the dynamic shape of one tensor by a string.

    Strings are kept. ``Dim.DYNAMIC`` and ``Dim.AUTO`` receive a generated name
    ``Dim<k>`` which is not already in *unique* nor used by another dimension
    of this tensor. Any other value is replaced by its ``__name__``.

    :param inputs: the tensor the dynamic shape applies to
    :param ds: dynamic shape of this tensor, ``{axis: dimension}``
    :param unique: names already in use, updated inplace with every name
        this call produces
    :return: a copy of *ds* with strings as values
    """
    assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
    assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
        f"Unexpected types, inputs is a Tensor but ds is {ds}, "
        f"a dictionary is expected to specify a dimension"
    )
    # Names given by the user for this tensor: a generated name must avoid them
    # even when they appear after the anonymous dimension in ds.
    user_names = {v for v in ds.values() if isinstance(v, str)}
    new_ds = ds.copy()
    for i, v in ds.items():
        if isinstance(v, str):
            name = v
        elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
            # Dim{len(unique)} alone may collide with a name already taken
            # (e.g. a user dimension called "Dim1"), two different dimensions
            # would then share the same name.
            index = len(unique)
            while f"Dim{index}" in unique or f"Dim{index}" in user_names:
                index += 1
            name = f"Dim{index}"
        else:
            name = v.__name__
        # NOTE(review): a user name appearing in a tensor visited later can
        # still match a name generated here; avoiding it needs a first pass
        # over the whole structure.
        unique.add(name)
        new_ds[i] = name
    return new_ds
|
|
164
|
+
|
|
89
165
|
def invalid_dimensions_for_export(self):
|
|
90
166
|
"""
|
|
91
167
|
Tells if the inputs are valid based on the dynamic shapes definition.
|
|
@@ -129,7 +205,7 @@ class CoupleInputsDynamicShapes:
|
|
|
129
205
|
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
|
|
130
206
|
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
|
|
131
207
|
"""
|
|
132
|
-
return self._generic_walker(self._valid_shapes_tensor)
|
|
208
|
+
return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)
|
|
133
209
|
|
|
134
210
|
@classmethod
|
|
135
211
|
def _valid_shapes_tensor(cls, inputs, ds):
|
|
@@ -147,7 +223,9 @@ class CoupleInputsDynamicShapes:
|
|
|
147
223
|
issues[i] = f"d=[{d}]"
|
|
148
224
|
return issues if issues else None
|
|
149
225
|
|
|
150
|
-
def _generic_walker(
|
|
226
|
+
def _generic_walker(
|
|
227
|
+
self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
|
|
228
|
+
):
|
|
151
229
|
"""
|
|
152
230
|
Generic deserializator walking through inputs and dynamic_shapes all along.
|
|
153
231
|
The function returns a result with the same structure as the dynamic shapes.
|
|
@@ -157,14 +235,23 @@ class CoupleInputsDynamicShapes:
|
|
|
157
235
|
f"Type mismatch, args={string_type(self.args)} and "
|
|
158
236
|
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
|
|
159
237
|
)
|
|
160
|
-
|
|
238
|
+
res = self._generic_walker_step(
|
|
239
|
+
processor,
|
|
240
|
+
self.kwargs,
|
|
241
|
+
self.dynamic_shapes,
|
|
242
|
+
flatten_unflatten=flatten_unflatten,
|
|
243
|
+
)
|
|
244
|
+
return (tuple(), res) if args_kwargs else res
|
|
161
245
|
|
|
162
246
|
if not self.kwargs:
|
|
163
247
|
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
|
|
164
248
|
f"Type mismatch, args={string_type(self.args)} and "
|
|
165
249
|
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
|
|
166
250
|
)
|
|
167
|
-
|
|
251
|
+
res = self._generic_walker_step(
|
|
252
|
+
processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
|
|
253
|
+
)
|
|
254
|
+
return (res, {}) if args_kwargs else res
|
|
168
255
|
|
|
169
256
|
assert isinstance(self.dynamic_shapes, dict), (
|
|
170
257
|
f"Both positional and named arguments (args and kwargs) are filled. "
|
|
@@ -174,12 +261,22 @@ class CoupleInputsDynamicShapes:
|
|
|
174
261
|
self.dynamic_shapes
|
|
175
262
|
):
|
|
176
263
|
# No dynamic shapes for the positional arguments.
|
|
177
|
-
return self._generic_walker_step(
|
|
264
|
+
return self._generic_walker_step(
|
|
265
|
+
processor,
|
|
266
|
+
self.kwargs,
|
|
267
|
+
self.dynamic_shapes,
|
|
268
|
+
flatten_unflatten=flatten_unflatten,
|
|
269
|
+
)
|
|
178
270
|
|
|
179
271
|
if isinstance(self.args_names, list):
|
|
180
272
|
if not set(self.args_names) & set(self.dynamic_shapes):
|
|
181
273
|
# No dynamic shapes for the positional arguments.
|
|
182
|
-
return self._generic_walker_step(
|
|
274
|
+
return self._generic_walker_step(
|
|
275
|
+
processor,
|
|
276
|
+
self.kwargs,
|
|
277
|
+
self.dynamic_shapes,
|
|
278
|
+
flatten_unflatten=flatten_unflatten,
|
|
279
|
+
)
|
|
183
280
|
|
|
184
281
|
assert self.args_names, (
|
|
185
282
|
"args and kwargs are filled, then args_names must be specified in "
|
|
@@ -192,7 +289,19 @@ class CoupleInputsDynamicShapes:
|
|
|
192
289
|
)
|
|
193
290
|
kwargs = dict(zip(self.args_names, self.args))
|
|
194
291
|
kwargs.update(self.kwargs)
|
|
195
|
-
|
|
292
|
+
res = self._generic_walker_step(
|
|
293
|
+
processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
|
|
294
|
+
)
|
|
295
|
+
if args_kwargs:
|
|
296
|
+
pgs = [None for _ in range(len(self.args))]
|
|
297
|
+
kws = {}
|
|
298
|
+
for k, v in res.items():
|
|
299
|
+
if k not in self.kwargs:
|
|
300
|
+
pgs[self.args_names.index(k)] = v
|
|
301
|
+
else:
|
|
302
|
+
kws[k] = v
|
|
303
|
+
return pgs, kws
|
|
304
|
+
return res
|
|
196
305
|
|
|
197
306
|
raise NotImplementedError(
|
|
198
307
|
f"Not yet implemented when args is filled, "
|
|
@@ -200,35 +309,48 @@ class CoupleInputsDynamicShapes:
|
|
|
200
309
|
)
|
|
201
310
|
|
|
202
311
|
@classmethod
|
|
203
|
-
def _generic_walker_step(
|
|
312
|
+
def _generic_walker_step(
|
|
313
|
+
cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
|
|
314
|
+
):
|
|
204
315
|
if isinstance(inputs, torch.Tensor):
|
|
205
316
|
return processor(inputs, ds)
|
|
206
317
|
if isinstance(inputs, (int, float, str)):
|
|
207
318
|
return None
|
|
208
|
-
if
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
319
|
+
if type(inputs) in (tuple, list, dict):
|
|
320
|
+
# Type must be strict, some custom classes can inherit from those.
|
|
321
|
+
assert type(inputs) is type(ds), (
|
|
322
|
+
f"Input type and dynamic shape type mush match but "
|
|
323
|
+
f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
|
|
324
|
+
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
|
|
325
|
+
)
|
|
212
326
|
assert len(ds) == len(inputs), (
|
|
213
327
|
f"Length mismatch between inputs {len(inputs)} "
|
|
214
328
|
f"and ds={len(ds)}\n"
|
|
215
329
|
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
|
|
216
330
|
)
|
|
217
|
-
if
|
|
331
|
+
if type(inputs) in (tuple, list):
|
|
218
332
|
value = []
|
|
219
333
|
for i, d in zip(inputs, ds):
|
|
220
|
-
value.append(
|
|
334
|
+
value.append(
|
|
335
|
+
cls._generic_walker_step(
|
|
336
|
+
processor, i, d, flatten_unflatten=flatten_unflatten
|
|
337
|
+
)
|
|
338
|
+
)
|
|
221
339
|
return (
|
|
222
340
|
(value if isinstance(ds, list) else tuple(value))
|
|
223
341
|
if any(v is not None for v in value)
|
|
224
342
|
else None
|
|
225
343
|
)
|
|
226
|
-
assert
|
|
227
|
-
|
|
228
|
-
|
|
344
|
+
assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
|
|
345
|
+
assert set(inputs) == set(ds), (
|
|
346
|
+
f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
|
|
347
|
+
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
|
|
348
|
+
)
|
|
229
349
|
dvalue = {}
|
|
230
350
|
for k, v in inputs.items():
|
|
231
|
-
t = cls._generic_walker_step(
|
|
351
|
+
t = cls._generic_walker_step(
|
|
352
|
+
processor, v, ds[k], flatten_unflatten=flatten_unflatten
|
|
353
|
+
)
|
|
232
354
|
if t is not None:
|
|
233
355
|
dvalue[k] = t
|
|
234
356
|
return dvalue if dvalue else None
|
|
@@ -239,12 +361,22 @@ class CoupleInputsDynamicShapes:
|
|
|
239
361
|
f"torch.utils._pytree.register_pytree_node, it is not possible to "
|
|
240
362
|
f"map this class with the given dynamic shapes."
|
|
241
363
|
)
|
|
364
|
+
if flatten_unflatten:
|
|
365
|
+
flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
|
|
366
|
+
return cls._generic_walker_step(
|
|
367
|
+
processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
|
|
368
|
+
)
|
|
242
369
|
flat, _spec = torch.utils._pytree.tree_flatten(inputs)
|
|
243
|
-
|
|
370
|
+
if all(isinstance(t, torch.Tensor) for t in flat):
|
|
371
|
+
# We need to flatten dynamic shapes as well
|
|
372
|
+
ds = flatten_dynamic_shapes(ds)
|
|
373
|
+
return cls._generic_walker_step(
|
|
374
|
+
processor, flat, ds, flatten_unflatten=flatten_unflatten
|
|
375
|
+
)
|
|
244
376
|
|
|
245
377
|
class ChangeDimensionProcessor:
|
|
246
|
-
def __init__(self):
|
|
247
|
-
self.mapping = {}
|
|
378
|
+
def __init__(self, desired_values):
|
|
379
|
+
self.mapping = desired_values or {}
|
|
248
380
|
|
|
249
381
|
def _build_new_shape(
|
|
250
382
|
self, shape: Tuple[int, ...], ds: Dict[int, Any]
|
|
@@ -285,14 +417,14 @@ class CoupleInputsDynamicShapes:
|
|
|
285
417
|
tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
|
|
286
418
|
)
|
|
287
419
|
mind = min(d0, d1)
|
|
288
|
-
indices = [slice(None) for _ in range(rank)]
|
|
420
|
+
indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
|
|
289
421
|
indices[i] = slice(0, mind)
|
|
290
422
|
ind = tuple(indices)
|
|
291
423
|
new_tensor[ind] = tensor[ind]
|
|
292
424
|
if d1 > mind:
|
|
293
425
|
for k in range(d1 - mind):
|
|
294
|
-
indices0 = [slice(None) for _ in range(rank)]
|
|
295
|
-
indices1 = [slice(None) for _ in range(rank)]
|
|
426
|
+
indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
|
|
427
|
+
indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
|
|
296
428
|
indices1[i] = mind + k
|
|
297
429
|
indices0[i] = k % mind
|
|
298
430
|
new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
|
|
@@ -310,7 +442,9 @@ class CoupleInputsDynamicShapes:
|
|
|
310
442
|
new_shape = self._build_new_shape(inputs.shape, ds)
|
|
311
443
|
return self._build_new_tensor(inputs, new_shape)
|
|
312
444
|
|
|
313
|
-
def change_dynamic_dimensions(
|
|
445
|
+
def change_dynamic_dimensions(
|
|
446
|
+
self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
|
|
447
|
+
):
|
|
314
448
|
"""
|
|
315
449
|
A model exported with dynamic shapes is not necessarily dynamic
|
|
316
450
|
just because the user specified dynamic shapes. The algorithm
|
|
@@ -320,6 +454,10 @@ class CoupleInputsDynamicShapes:
|
|
|
320
454
|
for the dimension than the first ones, assuming they were used to export
|
|
321
455
|
the model.
|
|
322
456
|
|
|
457
|
+
:param desired_values: to fixed named dimension to have the desired value
|
|
458
|
+
:param args_kwargs: return both args, kwargs even if empty
|
|
459
|
+
:return: new inputs
|
|
460
|
+
|
|
323
461
|
Example:
|
|
324
462
|
|
|
325
463
|
.. runpython::
|
|
@@ -340,7 +478,9 @@ class CoupleInputsDynamicShapes:
|
|
|
340
478
|
print("before:", string_type(kwargs, with_shape=True))
|
|
341
479
|
print("-after:", string_type(new_kwargs, with_shape=True))
|
|
342
480
|
"""
|
|
343
|
-
return self._generic_walker(
|
|
481
|
+
return self._generic_walker(
|
|
482
|
+
self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
|
|
483
|
+
)
|
|
344
484
|
|
|
345
485
|
|
|
346
486
|
class ModelInputs:
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import itertools
|
|
3
|
+
import time
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
5
|
+
import torch
|
|
6
|
+
from ..helpers import string_type, max_diff, string_diff
|
|
7
|
+
from ..helpers.torch_test_helper import torch_deepcopy
|
|
8
|
+
from .dynamic_shapes import CoupleInputsDynamicShapes
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def compare_modules(
    modep: torch.nn.Module,
    mod: Optional[torch.nn.Module] = None,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    exc: bool = True,
    verbose: int = 0,
    atol: float = 1e-2,
    rtol: float = 1e-1,
) -> Dict[str, Any]:
    """
    Compares two torch modules, usually one coming from an exported program,
    the other being the origin model.

    :param modep: first module, the one being checked
        (usually ``ExportedProgram.module()``)
    :param mod: second module (it produces the expected values),
        if None, *modep* is only executed and nothing is compared
    :param args: positional arguments
    :param kwargs: named arguments
    :param copy: copy the inputs before executing the model (they may modify them inplace)
    :param exc: raise exception if discrepancies are too high
    :param verbose: verbosity level
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :return: dictionary with keys ``args``, ``kwargs``, ``got`` (outputs of *modep*)
        and, only when *mod* is specified, ``expected`` (outputs of *mod*)
        and ``diff`` (discrepancies returned by ``max_diff``)
    :raises AssertionError: if *exc* is True and the discrepancies are
        higher than *atol* or *rtol*

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x, y)  # to make sure it runs

        ds = ({0: "a", 1: "b"}, {1: "b"})
        cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
        ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
        validate_ep(
            ep,
            model,
            args=(x, y),
            verbose=2,
            copy=True,
            dynamic_shapes=ds,
            values_to_try={"a": [5, 10], "b": [10, 20]},
        )
    """
    args = args or ()
    kwargs = kwargs or {}

    def _get(a):
        # A model may modify its inputs inplace, copying them keeps
        # the inputs identical for both modules.
        return torch_deepcopy(a) if copy else a

    if verbose:
        begin = time.perf_counter()
        print(
            f"[compare_modules] check ep with "
            f"args={string_type(args, with_shape=True, with_device=True)}, "
            f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..."
        )
    got = modep(*_get(args), **_get(kwargs))
    if verbose:
        d = time.perf_counter() - begin
        print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")

    # ``if mod:`` is not enough: container modules such as an empty
    # torch.nn.Sequential define __len__ and would evaluate to False.
    if mod is None:
        return dict(args=args, kwargs=kwargs, got=got)

    if verbose:
        begin = time.perf_counter()
        print("[compare_modules] run torch module...")
    expected = mod(*_get(args), **_get(kwargs))
    diff = max_diff(expected, got)
    if verbose:
        d = time.perf_counter() - begin
        print(
            f"[compare_modules] done in {d} with "
            f"output={string_type(expected, with_shape=True)}"
        )
        print(f"[compare_modules] discrepancies={string_diff(diff)}")
    # Explicit raise rather than assert: the check must not vanish under python -O.
    if exc and not (diff["abs"] <= atol and diff["rel"] <= rtol):
        raise AssertionError(f"Discrepancies={string_diff(diff)} higher than expected.")
    return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def validate_ep(
    ep: Union[torch.nn.Module, torch.export.ExportedProgram],
    mod: Optional[torch.nn.Module] = None,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    dynamic_shapes: Optional[Any] = None,
    values_to_try: Optional[Dict[str, List[int]]] = None,
    exc: bool = True,
    verbose: int = 0,
    atol: float = 1e-2,
    rtol: float = 1e-1,
) -> List[Dict[str, Any]]:
    """
    Validates an exported program.

    The exported program is first run on the given inputs. When *dynamic_shapes*
    and *values_to_try* are specified, new inputs are built for every
    combination of the values to try and the program is run again on each.

    :param ep: exported program (or module) to validate
    :param mod: second module (it produces the expected values)
    :param args: positional arguments
    :param kwargs: named arguments
    :param copy: copy the inputs before executing the model (they may modify them inplace)
    :param dynamic_shapes: dynamic shapes, string should be used not ``torch.export.Dim``
    :param values_to_try: dictionary with the values to try for every dynamic dimension
    :param exc: raise exception if discrepancies are too high
    :param verbose: verbosity level
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :return: list of dictionaries returned by ``compare_modules``, the first one
        for the given inputs, then one per combination of *values_to_try*
    :raises AssertionError: if only one of *dynamic_shapes*, *values_to_try*
        is specified, or (when *exc* is True) if discrepancies are too high
    """
    # Checked before running anything, the first run may be expensive.
    if bool(dynamic_shapes) != bool(values_to_try):
        raise AssertionError(
            "Either both dynamic_shapes and values_to_try are specified, either none."
        )

    modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep

    results = [
        compare_modules(
            modep, mod, args, kwargs, copy=copy, exc=exc, verbose=verbose, atol=atol, rtol=rtol
        )
    ]
    if not dynamic_shapes or not values_to_try:
        return results

    items = list(values_to_try.items())
    keys = [k for k, _ in items]
    all_vals = list(itertools.product(*(v for _, v in items)))
    cpl = CoupleInputsDynamicShapes(
        args or (),
        kwargs or {},
        dynamic_shapes,
        args_names=(
            list(inspect.signature(modep.forward).parameters) if args and kwargs else None
        ),
    )
    for i, vals in enumerate(all_vals):
        change_dims = dict(zip(keys, vals))
        if verbose:
            print(f"[validate_ep] try {i + 1}/{len(all_vals)}: {change_dims}")
        na, nkw = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
        results.append(
            compare_modules(
                modep,
                mod,
                na,
                nkw,
                copy=copy,
                exc=exc,
                verbose=max(verbose - 1, 0),
                atol=atol,
                rtol=rtol,
            )
        )
    return results
|