onnx-diagnostic 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +45 -4
  3. onnx_diagnostic/export/__init__.py +1 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +169 -29
  5. onnx_diagnostic/export/validate.py +170 -0
  6. onnx_diagnostic/ext_test_case.py +66 -3
  7. onnx_diagnostic/helpers/cache_helper.py +81 -5
  8. onnx_diagnostic/helpers/config_helper.py +80 -0
  9. onnx_diagnostic/helpers/helper.py +283 -81
  10. onnx_diagnostic/helpers/ort_session.py +1 -39
  11. onnx_diagnostic/helpers/rt_helper.py +47 -0
  12. onnx_diagnostic/helpers/torch_test_helper.py +14 -3
  13. onnx_diagnostic/tasks/__init__.py +48 -0
  14. onnx_diagnostic/tasks/automatic_speech_recognition.py +165 -0
  15. onnx_diagnostic/tasks/fill_mask.py +67 -0
  16. onnx_diagnostic/tasks/image_classification.py +96 -0
  17. onnx_diagnostic/tasks/image_text_to_text.py +145 -0
  18. onnx_diagnostic/tasks/sentence_similarity.py +67 -0
  19. onnx_diagnostic/tasks/text2text_generation.py +172 -0
  20. onnx_diagnostic/tasks/text_classification.py +67 -0
  21. onnx_diagnostic/tasks/text_generation.py +248 -0
  22. onnx_diagnostic/tasks/zero_shot_image_classification.py +106 -0
  23. onnx_diagnostic/torch_export_patches/__init__.py +0 -107
  24. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +21 -160
  25. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +339 -61
  26. onnx_diagnostic/torch_export_patches/patch_inputs.py +29 -0
  27. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +29 -0
  28. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +37 -47
  29. onnx_diagnostic/torch_models/hghub/hub_api.py +26 -6
  30. onnx_diagnostic/torch_models/hghub/hub_data.py +22 -14
  31. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +337 -25
  32. onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -608
  33. onnx_diagnostic/torch_models/test_helper.py +651 -228
  34. {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/METADATA +13 -3
  35. {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/RECORD +38 -25
  36. {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/WHEEL +1 -1
  37. {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.3.0.dist-info → onnx_diagnostic-0.4.0.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.3.0"
6
+ __version__ = "0.4.0"
7
7
  __author__ = "Xavier Dupré"
@@ -1,3 +1,4 @@
1
+ import argparse
1
2
  import json
2
3
  import sys
3
4
  import textwrap
@@ -227,6 +228,21 @@ def _cmd_config(argv: List[Any]):
227
228
  print(f"task: {task_from_id(args.mid)}")
228
229
 
229
230
 
231
+ class _ParseDict(argparse.Action):
232
+ def __call__(self, parser, namespace, values, option_string=None):
233
+ d = getattr(namespace, self.dest) or {}
234
+
235
+ if values:
236
+ for item in values:
237
+ split_items = item.split("=", 1)
238
+ key = split_items[0].strip() # we remove blanks around keys, as is logical
239
+ value = split_items[1]
240
+
241
+ d[key] = value
242
+
243
+ setattr(namespace, self.dest, d)
244
+
245
+
230
246
  def get_parser_validate() -> ArgumentParser:
231
247
  parser = ArgumentParser(
232
248
  prog="test",
@@ -287,22 +303,37 @@ def get_parser_validate() -> ArgumentParser:
287
303
  help="drops the following inputs names, it should be a list "
288
304
  "with comma separated values",
289
305
  )
306
+ parser.add_argument(
307
+ "--ortfusiontype",
308
+ required=False,
309
+ help="applies onnxruntime fusion, this parameter should contain the "
310
+ "model type or multiple values separated by `|`. `ALL` can be used "
311
+ "to run them all",
312
+ )
290
313
  parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
291
314
  parser.add_argument("--dtype", help="changes dtype if necessary")
292
315
  parser.add_argument("--device", help="changes the device if necessary")
316
+ parser.add_argument(
317
+ "--iop",
318
+ metavar="KEY=VALUE",
319
+ nargs="*",
320
+ help="Additional input options, use to change the default "
321
+ "inputs use to export, example: --iop cls_cache=SlidingWindowCache",
322
+ action=_ParseDict,
323
+ )
293
324
  return parser
294
325
 
295
326
 
296
327
  def _cmd_validate(argv: List[Any]):
297
328
  from .helpers import string_type
298
- from .torch_models.test_helper import get_inputs_for_task, validate_model, _ds_clean
299
- from .torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
329
+ from .torch_models.test_helper import get_inputs_for_task, validate_model
330
+ from .tasks import supported_tasks
300
331
 
301
332
  parser = get_parser_validate()
302
333
  args = parser.parse_args(argv[1:])
303
334
  if not args.task and not args.mid:
304
335
  print("-- list of supported tasks:")
305
- print("\n".join(sorted(get_get_inputs_function_for_tasks())))
336
+ print("\n".join(supported_tasks()))
306
337
  elif not args.mid:
307
338
  data = get_inputs_for_task(args.task)
308
339
  if args.verbose:
@@ -313,8 +344,16 @@ def _cmd_validate(argv: List[Any]):
313
344
  print(f" + {k.ljust(max_length)}: {string_type(v, with_shape=True)}")
314
345
  print("-- dynamic_shapes")
315
346
  for k, v in data["dynamic_shapes"].items():
316
- print(f" + {k.ljust(max_length)}: {_ds_clean(v)}")
347
+ print(f" + {k.ljust(max_length)}: {string_type(v)}")
317
348
  else:
349
+ # Let's skip any invalid combination if known to be unsupported
350
+ if (
351
+ "onnx" not in (args.export or "")
352
+ and "custom" not in (args.export or "")
353
+ and (args.opt or "")
354
+ ):
355
+ print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
356
+ return
318
357
  summary, _data = validate_model(
319
358
  model_id=args.mid,
320
359
  task=args.task,
@@ -330,6 +369,8 @@ def _cmd_validate(argv: List[Any]):
330
369
  exporter=args.export,
331
370
  dump_folder=args.dump_folder,
332
371
  drop_inputs=None if not args.drop else args.drop.split(","),
372
+ ortfusiontype=args.ortfusiontype,
373
+ input_options=args.iop,
333
374
  )
334
375
  print("")
335
376
  print("-- summary --")
@@ -1 +1,2 @@
1
1
  from .dynamic_shapes import CoupleInputsDynamicShapes, ModelInputs
2
+ from .validate import validate_ep
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
3
3
  import numpy as np
4
4
  import torch
5
5
  from ..helpers import string_type
@@ -8,6 +8,30 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
8
8
  DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
9
9
 
10
10
 
11
+ def flatten_dynamic_shapes(ds: Any) -> Any:
12
+ """Flattens the dynamic shapes."""
13
+ if isinstance(ds, list):
14
+ return _flat_list([flatten_dynamic_shapes(t) for t in ds])
15
+ if isinstance(ds, tuple):
16
+ return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
17
+ if isinstance(ds, dict):
18
+ if all(isinstance(i, int) for i in ds):
19
+ # That's a dynamic shape
20
+ return ds
21
+ return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
22
+ raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
23
+
24
+
25
+ def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
26
+ res = []
27
+ for t in li:
28
+ if isinstance(t, dict):
29
+ res.append(t)
30
+ else:
31
+ res.extend(t)
32
+ return res
33
+
34
+
11
35
  class CoupleInputsDynamicShapes:
12
36
  """
13
37
  Pair inputs / dynamic shapes.
@@ -68,7 +92,8 @@ class CoupleInputsDynamicShapes:
68
92
  return self._generic_walker(
69
93
  lambda inputs, ds, value=value: self._replace_string_dim_tensor(
70
94
  inputs, ds, value=value
71
- )
95
+ ),
96
+ flatten_unflatten=True,
72
97
  )
73
98
 
74
99
  @classmethod
@@ -76,7 +101,7 @@ class CoupleInputsDynamicShapes:
76
101
  assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
77
102
  assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
78
103
  f"Unexpected types, inputs is a Tensor but ds is {ds}, "
79
- f"a dictionary is expected to specify a dimension dimension"
104
+ f"a dictionary is expected to specify a dimension"
80
105
  )
81
106
  if value is None:
82
107
  value = torch.export.Dim.DYNAMIC
@@ -86,6 +111,57 @@ class CoupleInputsDynamicShapes:
86
111
  new_ds[i] = value
87
112
  return new_ds
88
113
 
114
+ def replace_by_string(self):
115
+ """
116
+ Replaces dimensions by strings.
117
+
118
+ Example:
119
+
120
+ .. runpython::
121
+ :showcode:
122
+
123
+ import torch
124
+ from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
125
+
126
+ Dim = torch.export.Dim
127
+ T3x1 = torch.rand((3, 1))
128
+ T3x4 = torch.rand((3, 4))
129
+ ds_batch = {0: Dim("batch")}
130
+ ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
131
+ kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
132
+ ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
133
+ print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
134
+ """
135
+ unique = set()
136
+ return self._generic_walker(
137
+ lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
138
+ inputs, ds, unique=unique
139
+ ),
140
+ flatten_unflatten=True,
141
+ )
142
+
143
+ @classmethod
144
+ def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
145
+ assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
146
+ assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
147
+ f"Unexpected types, inputs is a Tensor but ds is {ds}, "
148
+ f"a dictionary is expected to specify a dimension"
149
+ )
150
+ new_ds = ds.copy()
151
+ for i, v in ds.items():
152
+ if isinstance(v, str):
153
+ unique.add(v)
154
+ new_ds[i] = v
155
+ elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
156
+ name = f"Dim{len(unique)}"
157
+ new_ds[i] = name
158
+ unique.add(name)
159
+ else:
160
+ name = v.__name__
161
+ unique.add(name)
162
+ new_ds[i] = name
163
+ return new_ds
164
+
89
165
  def invalid_dimensions_for_export(self):
90
166
  """
91
167
  Tells if the inputs are valid based on the dynamic shapes definition.
@@ -129,7 +205,7 @@ class CoupleInputsDynamicShapes:
129
205
  ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
130
206
  print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
131
207
  """
132
- return self._generic_walker(self._valid_shapes_tensor)
208
+ return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)
133
209
 
134
210
  @classmethod
135
211
  def _valid_shapes_tensor(cls, inputs, ds):
@@ -147,7 +223,9 @@ class CoupleInputsDynamicShapes:
147
223
  issues[i] = f"d=[{d}]"
148
224
  return issues if issues else None
149
225
 
150
- def _generic_walker(self, processor: Callable):
226
+ def _generic_walker(
227
+ self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
228
+ ):
151
229
  """
152
230
  Generic deserializator walking through inputs and dynamic_shapes all along.
153
231
  The function returns a result with the same structure as the dynamic shapes.
@@ -157,14 +235,23 @@ class CoupleInputsDynamicShapes:
157
235
  f"Type mismatch, args={string_type(self.args)} and "
158
236
  f"dynamic_shapes={self.dynamic_shapes} should have the same type."
159
237
  )
160
- return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
238
+ res = self._generic_walker_step(
239
+ processor,
240
+ self.kwargs,
241
+ self.dynamic_shapes,
242
+ flatten_unflatten=flatten_unflatten,
243
+ )
244
+ return (tuple(), res) if args_kwargs else res
161
245
 
162
246
  if not self.kwargs:
163
247
  assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
164
248
  f"Type mismatch, args={string_type(self.args)} and "
165
249
  f"dynamic_shapes={self.dynamic_shapes} should have the same type."
166
250
  )
167
- return self._generic_walker_step(processor, self.args, self.dynamic_shapes)
251
+ res = self._generic_walker_step(
252
+ processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
253
+ )
254
+ return (res, {}) if args_kwargs else res
168
255
 
169
256
  assert isinstance(self.dynamic_shapes, dict), (
170
257
  f"Both positional and named arguments (args and kwargs) are filled. "
@@ -174,12 +261,22 @@ class CoupleInputsDynamicShapes:
174
261
  self.dynamic_shapes
175
262
  ):
176
263
  # No dynamic shapes for the positional arguments.
177
- return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
264
+ return self._generic_walker_step(
265
+ processor,
266
+ self.kwargs,
267
+ self.dynamic_shapes,
268
+ flatten_unflatten=flatten_unflatten,
269
+ )
178
270
 
179
271
  if isinstance(self.args_names, list):
180
272
  if not set(self.args_names) & set(self.dynamic_shapes):
181
273
  # No dynamic shapes for the positional arguments.
182
- return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)
274
+ return self._generic_walker_step(
275
+ processor,
276
+ self.kwargs,
277
+ self.dynamic_shapes,
278
+ flatten_unflatten=flatten_unflatten,
279
+ )
183
280
 
184
281
  assert self.args_names, (
185
282
  "args and kwargs are filled, then args_names must be specified in "
@@ -192,7 +289,19 @@ class CoupleInputsDynamicShapes:
192
289
  )
193
290
  kwargs = dict(zip(self.args_names, self.args))
194
291
  kwargs.update(self.kwargs)
195
- return self._generic_walker_step(processor, kwargs, self.dynamic_shapes)
292
+ res = self._generic_walker_step(
293
+ processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
294
+ )
295
+ if args_kwargs:
296
+ pgs = [None for _ in range(len(self.args))]
297
+ kws = {}
298
+ for k, v in res.items():
299
+ if k not in self.kwargs:
300
+ pgs[self.args_names.index(k)] = v
301
+ else:
302
+ kws[k] = v
303
+ return pgs, kws
304
+ return res
196
305
 
197
306
  raise NotImplementedError(
198
307
  f"Not yet implemented when args is filled, "
@@ -200,35 +309,48 @@ class CoupleInputsDynamicShapes:
200
309
  )
201
310
 
202
311
  @classmethod
203
- def _generic_walker_step(cls, processor: Callable, inputs, ds):
312
+ def _generic_walker_step(
313
+ cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
314
+ ):
204
315
  if isinstance(inputs, torch.Tensor):
205
316
  return processor(inputs, ds)
206
317
  if isinstance(inputs, (int, float, str)):
207
318
  return None
208
- if isinstance(inputs, (tuple, list, dict)):
209
- assert type(ds) is type(
210
- inputs
211
- ), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}"
319
+ if type(inputs) in (tuple, list, dict):
320
+ # Type must be strict, some custom classes can inherit from those.
321
+ assert type(inputs) is type(ds), (
322
+ f"Input type and dynamic shape type mush match but "
323
+ f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
324
+ f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
325
+ )
212
326
  assert len(ds) == len(inputs), (
213
327
  f"Length mismatch between inputs {len(inputs)} "
214
328
  f"and ds={len(ds)}\n"
215
329
  f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
216
330
  )
217
- if isinstance(inputs, (tuple, list)):
331
+ if type(inputs) in (tuple, list):
218
332
  value = []
219
333
  for i, d in zip(inputs, ds):
220
- value.append(cls._generic_walker_step(processor, i, d))
334
+ value.append(
335
+ cls._generic_walker_step(
336
+ processor, i, d, flatten_unflatten=flatten_unflatten
337
+ )
338
+ )
221
339
  return (
222
340
  (value if isinstance(ds, list) else tuple(value))
223
341
  if any(v is not None for v in value)
224
342
  else None
225
343
  )
226
- assert set(inputs) == set(
227
- ds
228
- ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
344
+ assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
345
+ assert set(inputs) == set(ds), (
346
+ f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
347
+ f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
348
+ )
229
349
  dvalue = {}
230
350
  for k, v in inputs.items():
231
- t = cls._generic_walker_step(processor, v, ds[k])
351
+ t = cls._generic_walker_step(
352
+ processor, v, ds[k], flatten_unflatten=flatten_unflatten
353
+ )
232
354
  if t is not None:
233
355
  dvalue[k] = t
234
356
  return dvalue if dvalue else None
@@ -239,12 +361,22 @@ class CoupleInputsDynamicShapes:
239
361
  f"torch.utils._pytree.register_pytree_node, it is not possible to "
240
362
  f"map this class with the given dynamic shapes."
241
363
  )
364
+ if flatten_unflatten:
365
+ flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
366
+ return cls._generic_walker_step(
367
+ processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
368
+ )
242
369
  flat, _spec = torch.utils._pytree.tree_flatten(inputs)
243
- return cls._generic_walker_step(processor, flat, ds)
370
+ if all(isinstance(t, torch.Tensor) for t in flat):
371
+ # We need to flatten dynamic shapes as well
372
+ ds = flatten_dynamic_shapes(ds)
373
+ return cls._generic_walker_step(
374
+ processor, flat, ds, flatten_unflatten=flatten_unflatten
375
+ )
244
376
 
245
377
  class ChangeDimensionProcessor:
246
- def __init__(self):
247
- self.mapping = {}
378
+ def __init__(self, desired_values):
379
+ self.mapping = desired_values or {}
248
380
 
249
381
  def _build_new_shape(
250
382
  self, shape: Tuple[int, ...], ds: Dict[int, Any]
@@ -285,14 +417,14 @@ class CoupleInputsDynamicShapes:
285
417
  tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
286
418
  )
287
419
  mind = min(d0, d1)
288
- indices = [slice(None) for _ in range(rank)]
420
+ indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
289
421
  indices[i] = slice(0, mind)
290
422
  ind = tuple(indices)
291
423
  new_tensor[ind] = tensor[ind]
292
424
  if d1 > mind:
293
425
  for k in range(d1 - mind):
294
- indices0 = [slice(None) for _ in range(rank)]
295
- indices1 = [slice(None) for _ in range(rank)]
426
+ indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
427
+ indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
296
428
  indices1[i] = mind + k
297
429
  indices0[i] = k % mind
298
430
  new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
@@ -310,7 +442,9 @@ class CoupleInputsDynamicShapes:
310
442
  new_shape = self._build_new_shape(inputs.shape, ds)
311
443
  return self._build_new_tensor(inputs, new_shape)
312
444
 
313
- def change_dynamic_dimensions(self):
445
+ def change_dynamic_dimensions(
446
+ self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
447
+ ):
314
448
  """
315
449
  A model exported with dynamic shapes is not necessarily dynamic
316
450
  just because the user specified dynamic shapes. The algorithm
@@ -320,6 +454,10 @@ class CoupleInputsDynamicShapes:
320
454
  for the dimension than the first ones, assuming they were used to export
321
455
  the model.
322
456
 
457
+ :param desired_values: to fixed named dimension to have the desired value
458
+ :param args_kwargs: return both args, kwargs even if empty
459
+ :return: new inputs
460
+
323
461
  Example:
324
462
 
325
463
  .. runpython::
@@ -340,7 +478,9 @@ class CoupleInputsDynamicShapes:
340
478
  print("before:", string_type(kwargs, with_shape=True))
341
479
  print("-after:", string_type(new_kwargs, with_shape=True))
342
480
  """
343
- return self._generic_walker(self.ChangeDimensionProcessor())
481
+ return self._generic_walker(
482
+ self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
483
+ )
344
484
 
345
485
 
346
486
  class ModelInputs:
@@ -0,0 +1,170 @@
1
+ import inspect
2
+ import itertools
3
+ import time
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+ import torch
6
+ from ..helpers import string_type, max_diff, string_diff
7
+ from ..helpers.torch_test_helper import torch_deepcopy
8
+ from .dynamic_shapes import CoupleInputsDynamicShapes
9
+
10
+
11
+ def compare_modules(
12
+ modep: torch.nn.Module,
13
+ mod: Optional[torch.nn.Module] = None,
14
+ args: Optional[Tuple[Any, ...]] = None,
15
+ kwargs: Optional[Dict[str, Any]] = None,
16
+ copy: bool = False,
17
+ exc: bool = True,
18
+ verbose: int = 0,
19
+ atol: float = 1e-2,
20
+ rtol: float = 1e-1,
21
+ ) -> Dict[str, Any]:
22
+ """
23
+ Compares two torch modules, usually one coming from an exported program,
24
+ the other being the origin model.
25
+
26
+ :param model: first module
27
+ :param mod: second module (it produces the expected values)
28
+ :param args: positional arguments
29
+ :param kwargs: named arguments
30
+ :param copy: copy the inputs before executing the model (they may modify them inplace)
31
+ :param exc: raise exception if discrepancies are too high
32
+ :param verbose: verbosity level
33
+ :param atol: absolute tolerance
34
+ :param rtol: relative tolerance
35
+ :return: dictionary with inputs, outputs and tolerance
36
+
37
+ Example:
38
+
39
+ .. runpython::
40
+ :showcode:
41
+
42
+ import torch
43
+ from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes
44
+
45
+ class Model(torch.nn.Module):
46
+ def forward(self, x, y):
47
+ return x + y
48
+
49
+ model = Model()
50
+ x = torch.randn((5, 6))
51
+ y = torch.randn((1, 6))
52
+ model(x, y) # to make it is running
53
+
54
+ ds = ({0: "a", 1: "b"}, {1: "b"})
55
+ cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
56
+ ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
57
+ validate_ep(
58
+ ep,
59
+ model,
60
+ args=(x, y),
61
+ verbose=2,
62
+ copy=True,
63
+ dynamic_shapes=ds,
64
+ values_to_try={"a": [5, 10], "b": [10, 20]},
65
+ )
66
+
67
+ """
68
+ args = args or ()
69
+ kwargs = kwargs or {}
70
+
71
+ def _get(a):
72
+ return torch_deepcopy(a) if copy else a
73
+
74
+ if verbose:
75
+ begin = time.perf_counter()
76
+ print(
77
+ f"[compare_modules] check ep with "
78
+ f"args={string_type(args, with_shape=True, with_device=True)}, "
79
+ f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..."
80
+ )
81
+ got = modep(*_get(args), **_get(kwargs))
82
+ if verbose:
83
+ d = time.perf_counter() - begin
84
+ print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
85
+ if mod:
86
+ if verbose:
87
+ begin = time.perf_counter()
88
+ print("[compare_modules] run torch module...")
89
+ expected = mod(*_get(args), **_get(kwargs))
90
+ diff = max_diff(expected, got)
91
+ if verbose:
92
+ d = time.perf_counter() - begin
93
+ print(
94
+ f"[compare_modules] done in {d} with "
95
+ f"output={string_type(expected, with_shape=True)}"
96
+ )
97
+ print(f"[compare_modules] discrepancies={string_diff(diff)}")
98
+ assert not exc or (
99
+ diff["abs"] <= atol and diff["rel"] <= rtol
100
+ ), f"Discrepancies={string_diff(diff)} higher than expected."
101
+ return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
102
+ return dict(args=args, kwargs=kwargs, got=got)
103
+
104
+
105
+ def validate_ep(
106
+ ep: Union[torch.nn.Module, torch.export.ExportedProgram],
107
+ mod: Optional[torch.nn.Module] = None,
108
+ args: Optional[Tuple[Any, ...]] = None,
109
+ kwargs: Optional[Dict[str, Any]] = None,
110
+ copy: bool = False,
111
+ dynamic_shapes: Optional[Any] = None,
112
+ values_to_try: Optional[Dict[str, List[int]]] = None,
113
+ exc: bool = True,
114
+ verbose: int = 0,
115
+ atol: float = 1e-2,
116
+ rtol: float = 1e-1,
117
+ ) -> List[Dict[str, Any]]:
118
+ """
119
+ Validates an exported program.
120
+
121
+ :param model: first module
122
+ :param mod: second module (it produces the expected values)
123
+ :param args: positional arguments
124
+ :param kwargs: named arguments
125
+ :param copy: copy the inputs before executing the model (they may modify them inplace)
126
+ :param dynamic_shapes: dynamic shapes, string should be used not ``torch.export.Dim``
127
+ :param values_to_try: dictionary with the values to try for every dynamic dimension
128
+ :param exc: raise exception if discrepancies are too high
129
+ :param verbose: verbosity level
130
+ :param atol: absolute tolerance
131
+ :param rtol: relative tolerance
132
+ :return: dictionary with inputs, outputs and tolerance
133
+ """
134
+ modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep
135
+
136
+ results = [
137
+ compare_modules(
138
+ modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol
139
+ )
140
+ ]
141
+
142
+ assert (dynamic_shapes and values_to_try) or (
143
+ not dynamic_shapes and not values_to_try
144
+ ), "Either both dynamic_shapes and values_to_try are specified, either none."
145
+ if not dynamic_shapes or not values_to_try:
146
+ return results
147
+
148
+ items = list(values_to_try.items())
149
+ keys = [_[0] for _ in items]
150
+ values = [_[1] for _ in items]
151
+ all_vals = list(itertools.product(*values))
152
+ cpl = CoupleInputsDynamicShapes(
153
+ args or (),
154
+ kwargs or {},
155
+ dynamic_shapes,
156
+ args_names=(
157
+ list(inspect.signature(modep.forward).parameters) if args and kwargs else None
158
+ ),
159
+ )
160
+ for i, vals in enumerate(all_vals):
161
+ change_dims = dict(zip(keys, vals))
162
+ if verbose:
163
+ print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}")
164
+ new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
165
+ na, nkw = new_params
166
+ c = compare_modules(
167
+ modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol
168
+ )
169
+ results.append(c)
170
+ return results