onnx-diagnostic 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +411 -0
  4. onnx_diagnostic/doc.py +32 -0
  5. onnx_diagnostic/export/__init__.py +1 -1
  6. onnx_diagnostic/export/dynamic_shapes.py +433 -22
  7. onnx_diagnostic/ext_test_case.py +90 -29
  8. onnx_diagnostic/helpers/__init__.py +1 -0
  9. onnx_diagnostic/helpers/bench_run.py +450 -0
  10. onnx_diagnostic/{cache_helpers.py → helpers/cache_helper.py} +62 -4
  11. onnx_diagnostic/{helpers.py → helpers/helper.py} +136 -659
  12. onnx_diagnostic/helpers/memory_peak.py +249 -0
  13. onnx_diagnostic/helpers/onnx_helper.py +921 -0
  14. onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +54 -4
  15. onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +142 -55
  16. onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
  17. onnx_diagnostic/reference/ort_evaluator.py +7 -2
  18. onnx_diagnostic/torch_export_patches/__init__.py +107 -0
  19. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +160 -28
  20. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +13 -2
  21. onnx_diagnostic/torch_export_patches/patch_inputs.py +174 -0
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +194 -1
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +18 -5
  24. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  25. onnx_diagnostic/torch_models/hghub/hub_api.py +234 -0
  26. onnx_diagnostic/torch_models/hghub/hub_data.py +195 -0
  27. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3259 -0
  28. onnx_diagnostic/torch_models/hghub/model_inputs.py +727 -0
  29. onnx_diagnostic/torch_models/llms.py +2 -96
  30. onnx_diagnostic/torch_models/test_helper.py +827 -0
  31. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +108 -0
  33. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +103 -0
  34. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  35. onnx_diagnostic/torch_onnx/sbs.py +439 -0
  36. {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/METADATA +4 -2
  37. onnx_diagnostic-0.3.0.dist-info/RECORD +73 -0
  38. {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/WHEEL +1 -1
  39. onnx_diagnostic/onnx_tools.py +0 -260
  40. onnx_diagnostic-0.2.1.dist-info/RECORD +0 -55
  41. /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
  42. {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.2.1.dist-info → onnx_diagnostic-0.3.0.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,344 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
3
  import numpy as np
4
4
  import torch
5
5
  from ..helpers import string_type
6
+ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
7
+
8
+ DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
9
+
10
+
11
class CoupleInputsDynamicShapes:
    """
    Pair inputs / dynamic shapes.

    The class walks through the inputs and the dynamic shapes simultaneously
    and applies a *processor* on every couple *(tensor, dynamic shape)*.
    The result has the same structure as the dynamic shapes.

    :param args: positional arguments
    :param kwargs: named arguments
    :param dynamic_shapes: dynamic shapes
    :param args_names: if both args and kwargs are not empty, then
        dynamic shapes must be a dictionary, and positional must be added
        to the named arguments. Arguments names or a module must be given
        in that case.
    """

    def __init__(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        dynamic_shapes: "DYNAMIC_SHAPES",
        args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
    ):
        self.args = args
        self.kwargs = kwargs
        self.dynamic_shapes = dynamic_shapes
        self.args_names = args_names

    def __str__(self) -> str:
        "Returns a multi-line description, one attribute per line."
        # Every element ends with a comma *outside* the string: without it,
        # adjacent f-strings are implicitly concatenated into a single line.
        return "\n".join(
            [
                f"{self.__class__.__name__}(",
                f"    args={string_type(self.args, with_shape=True)},",
                f"    kwargs={string_type(self.kwargs, with_shape=True)},",
                f"    dynamic_shapes={string_type(self.dynamic_shapes, with_shape=True)},",
                ")",
            ]
        )

    def replace_string_by(self, value: Any = None):
        """
        Replaces string by the value ``torch.export.Dim.DYNAMIC``
        (default) or any other value specified by value.

        :param value: replacement for every dimension defined by a string,
            None means ``torch.export.Dim.DYNAMIC``
        :return: the new dynamic shapes, same structure as the existing ones

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())
        """
        return self._generic_walker(
            lambda inputs, ds, value=value: self._replace_string_dim_tensor(
                inputs, ds, value=value
            )
        )

    @classmethod
    def _replace_string_dim_tensor(cls, inputs, ds, value=None):
        """
        Processor for one tensor: returns a copy of *ds* where every
        dimension defined by a string is replaced by *value*
        (``torch.export.Dim.DYNAMIC`` if *value* is None).
        """
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension dimension"
        )
        if value is None:
            value = torch.export.Dim.DYNAMIC
        # Non-string entries (integers, Dim objects) are kept unchanged.
        return {i: (value if isinstance(v, str) else v) for i, v in ds.items()}

    def invalid_dimensions_for_export(self):
        """
        Tells if the inputs are valid based on the dynamic shapes definition.
        The method assumes that all custom classes can be serialized.
        If some patches were applied to export, they should be enabled while
        calling this method if the inputs contains such classes.

        The function checks that a dynamic dimension does not receive a value
        of 0 or 1. It returns the unexpected values in the same structure as
        the given dynamic shapes.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x1 = torch.rand((3, 1))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())

        In case it works, it shows:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x2 = torch.rand((3, 2))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x2, T3x2)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
        """
        return self._generic_walker(self._valid_shapes_tensor)

    @classmethod
    def _valid_shapes_tensor(cls, inputs, ds):
        """
        Processor for one tensor: returns ``{axis: "d=[value]"}`` for every
        axis declared with something else than an integer in *ds* and whose
        actual size is 0 or 1, None if there is no such axis.
        """
        assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
        assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
            f"Unexpected types, inputs is a Tensor but ds is {ds}, "
            f"a dictionary is expected to specify a dimension dimension"
        )
        issues = {}
        for i, d in enumerate(inputs.shape):
            # An integer in ds means the dimension is static, nothing to check.
            if i in ds and not isinstance(ds[i], int) and d in {0, 1}:
                # export issues for sure
                issues[i] = f"d=[{d}]"
        return issues if issues else None

    def _generic_walker(self, processor: Callable):
        """
        Generic deserializer walking through inputs and dynamic_shapes all along.
        The function returns a result with the same structure as the dynamic shapes.

        :param processor: function called on every couple
            *(tensor, dynamic shape of this tensor)*
        :return: see above, None where the processor returns None everywhere
        """
        if not self.args:
            # Only named arguments: dynamic shapes must be a dictionary too.
            assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
                f"Type mismatch, kwargs={string_type(self.kwargs)} and "
                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
            )
            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)

        if not self.kwargs:
            # Only positional arguments: dynamic shapes must be a tuple too.
            assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
                f"Type mismatch, args={string_type(self.args)} and "
                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
            )
            return self._generic_walker_step(processor, self.args, self.dynamic_shapes)

        assert isinstance(self.dynamic_shapes, dict), (
            f"Both positional and named arguments (args and kwargs) are filled. "
            f"dynamic shapes must a dictionary not {type(self.dynamic_shapes)}"
        )
        if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
            self.dynamic_shapes
        ):
            # No dynamic shapes for the positional arguments.
            return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)

        if isinstance(self.args_names, list):
            if not set(self.args_names) & set(self.dynamic_shapes):
                # No dynamic shapes for the positional arguments.
                return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes)

            assert self.args_names, (
                "args and kwargs are filled, then args_names must be specified in "
                "the constructor to move positional arguments to named arguments."
            )
            assert len(self.args) <= len(self.args_names), (
                f"There are {len(self.args)} positional arguments "
                f"but only {len(self.args_names)} names. "
                f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
            )
            # Positional arguments become named arguments, explicit named
            # arguments take precedence if a name appears twice.
            kwargs = dict(zip(self.args_names, self.args))
            kwargs.update(self.kwargs)
            return self._generic_walker_step(processor, kwargs, self.dynamic_shapes)

        raise NotImplementedError(
            f"Not yet implemented when args is filled, "
            f"kwargs as well but args_names is {type(self.args_names)}"
        )

    @classmethod
    def _generic_walker_step(cls, processor: Callable, inputs, ds):
        """
        Recursive step of :meth:`_generic_walker`.

        * a tensor: returns ``processor(inputs, ds)``,
        * an int, a float, a string: returns None,
        * a tuple, a list: *ds* must have the same type and length,
          returns a sequence of the same type as *ds* or None if
          every element is None,
        * a dictionary: *ds* must have the same keys, returns a dictionary
          with the keys for which the result is not None or None if empty,
        * any other class must be registered in
          ``torch.utils._pytree.SUPPORTED_NODES``, it is flattened with
          ``tree_flatten`` and the resulting list is walked against *ds*.
        """
        if isinstance(inputs, torch.Tensor):
            return processor(inputs, ds)
        if isinstance(inputs, (int, float, str)):
            return None
        if isinstance(inputs, (tuple, list, dict)):
            assert type(ds) is type(
                inputs
            ), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}"
            assert len(ds) == len(inputs), (
                f"Length mismatch between inputs {len(inputs)} "
                f"and ds={len(ds)}\n"
                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
            )
            if isinstance(inputs, (tuple, list)):
                value = [cls._generic_walker_step(processor, i, d) for i, d in zip(inputs, ds)]
                if all(v is None for v in value):
                    return None
                return value if isinstance(ds, list) else tuple(value)
            assert set(inputs) == set(
                ds
            ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}"
            dvalue = {}
            for k, v in inputs.items():
                t = cls._generic_walker_step(processor, v, ds[k])
                if t is not None:
                    dvalue[k] = t
            return dvalue if dvalue else None

        # A custom class.
        assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
            f"Class {inputs.__class__.__name__!r} was not registered using "
            f"torch.utils._pytree.register_pytree_node, it is not possible to "
            f"map this class with the given dynamic shapes."
        )
        flat, _spec = torch.utils._pytree.tree_flatten(inputs)
        return cls._generic_walker_step(processor, flat, ds)

    class ChangeDimensionProcessor:
        """
        Processor building, for every tensor, a new tensor whose dynamic
        dimensions have a different size. Attribute *mapping* stores the size
        chosen for every dimension name so that the same name gets the same
        size in every tensor.
        """

        def __init__(self):
            self.mapping = {}

        def _build_new_shape(
            self, shape: Tuple[int, ...], ds: Dict[int, Any]
        ) -> Tuple[int, ...]:
            """
            Returns *shape* where every dynamic axis is replaced by the size
            stored in *mapping* for its name, or by ``shape[i] + 1`` the first
            time the name is met. Axes missing from *ds* or declared with an
            integer are static and left unchanged.
            """
            new_shape = list(shape)
            for i in range(len(shape)):
                if i not in ds:
                    continue
                spec = ds[i]
                if isinstance(spec, int):
                    # Static dimension: it must not reuse the name
                    # of a previously processed axis.
                    continue
                if isinstance(spec, str):
                    d = spec
                elif isinstance(
                    spec,
                    (
                        torch.export.dynamic_shapes._DerivedDim,
                        torch.export.dynamic_shapes._Dim,
                    ),
                ):
                    d = str(spec)
                else:
                    raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
                if d in self.mapping:
                    new_dim = self.mapping[d]
                else:
                    new_dim = shape[i] + 1
                    self.mapping[d] = new_dim
                new_shape[i] = new_dim
            return tuple(new_shape)

        def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
            """
            Resizes *tensor* to *new_shape*, one axis after the other.
            Existing values are kept, additional slices are copies of the
            first ones (cyclically). If the axis was empty, there is nothing
            to copy and the additional slices remain null.
            """
            rank = len(tensor.shape)
            for i in range(len(tensor.shape)):
                d0 = tensor.shape[i]
                d1 = new_shape[i]
                if d0 == d1:
                    continue
                alt_shape = list(tensor.shape)
                alt_shape[i] = d1
                new_tensor = torch.zeros(
                    tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
                )
                mind = min(d0, d1)
                indices = [slice(None) for _ in range(rank)]
                indices[i] = slice(0, mind)
                ind = tuple(indices)
                new_tensor[ind] = tensor[ind]
                # mind == 0 means the original axis is empty: k % mind
                # would fail and there is no slice to duplicate.
                if d1 > mind and mind > 0:
                    for k in range(d1 - mind):
                        indices0 = [slice(None) for _ in range(rank)]
                        indices1 = [slice(None) for _ in range(rank)]
                        indices1[i] = mind + k
                        indices0[i] = k % mind
                        new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
                tensor = new_tensor
            return tensor

        def __call__(self, inputs, ds):
            "Returns a resized copy of *inputs* based on its dynamic shapes *ds*."
            assert isinstance(
                inputs, torch.Tensor
            ), f"unexpected type for inputs {type(inputs)}"
            assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
                f"Unexpected types, inputs is a Tensor but ds is {ds}, "
                f"a dictionary is expected to specify a dimension dimension"
            )
            new_shape = self._build_new_shape(inputs.shape, ds)
            return self._build_new_tensor(inputs, new_shape)

    def change_dynamic_dimensions(self):
        """
        A model exported with dynamic shapes is not necessarily dynamic
        just because the user specified dynamic shapes. The algorithm
        may discover that a dimension cannot be dynamic and then continues
        the export making the assumption it is static. That may lead a wrong
        model. This function produces a new set of inputs with different values
        for the dimension than the first ones, assuming they were used to export
        the model.

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

            T3x15 = torch.rand((3, 15))
            T3x20 = torch.rand((3, 20))
            T3x4 = torch.rand((3, 4))
            ds_batch = {0: "batch"}
            ds_batch_seq = {0: "batch", 1: "seq"}
            kwargs = {"A": T3x4, "B": (T3x15, T3x20)}
            ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
            new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
            print("before:", string_type(kwargs, with_shape=True))
            print("-after:", string_type(new_kwargs, with_shape=True))
        """
        return self._generic_walker(self.ChangeDimensionProcessor())
6
344
 
7
345
 
8
346
  class ModelInputs:
@@ -69,7 +407,7 @@ class ModelInputs:
69
407
  ds = mi.guess_dynamic_shapes()
70
408
  pprint.pprint(ds)
71
409
 
72
- **and and kwargs**
410
+ **args and kwargs**
73
411
 
74
412
  .. runpython::
75
413
  :showcode:
@@ -218,7 +556,7 @@ class ModelInputs:
218
556
  return new_inputs
219
557
 
220
558
  @property
221
- def true_model_name(self):
559
+ def true_model_name(self) -> str:
222
560
  "Returns class name or module name."
223
561
  return (
224
562
  self.model.__class__.__name__
@@ -227,7 +565,7 @@ class ModelInputs:
227
565
  )
228
566
 
229
567
  @property
230
- def full_name(self):
568
+ def full_name(self) -> str:
231
569
  "Returns a name and class name."
232
570
  if self.method_name == "forward":
233
571
  return f"{self.name}:{self.true_model_name}"
@@ -240,10 +578,27 @@ class ModelInputs:
240
578
  return f"type({self.name})={self.true_model_name}"
241
579
  return f"type({self.name})={self.true_model_name}.{self.method_name}"
242
580
 
243
- def guess_dynamic_dimensions(self, *tensors) -> Dict[int, Any]:
244
- """Infers the dynamic dimension from multiple shapes."""
581
+ def guess_dynamic_dimensions(
582
+ self, *tensors, auto: bool = False
583
+ ) -> Optional[Dict[int, Any]]:
584
+ """
585
+ Infers the dynamic dimension from multiple shapes.
586
+ If auto is True, it returns ``torch.export.Dim.AUTO`` for every dimension
587
+ which cannot be guessed. Two tensors with the same value for one dimension
588
+ can be guessed, but if there is only 1, it cannot.
589
+ """
245
590
  if len(tensors) == 1:
246
- return {}
591
+ if isinstance(tensors[0], (int, float)):
592
+ return None
593
+ assert isinstance(tensors[0], torch.Tensor), (
594
+ f"Unexpected type for tensors {string_type(tensors, with_shape=True)}, "
595
+ f"Only tensors are allowed."
596
+ )
597
+ return (
598
+ {i: torch.export.Dim.AUTO for i in range(len(tensors[0].shape))} # noqa: C420
599
+ if auto
600
+ else {}
601
+ )
247
602
  shapes = [t.shape for t in tensors]
248
603
  set_length = set(len(s) for s in shapes)
249
604
  assert len(set_length) == 1, (
@@ -265,7 +620,9 @@ class ModelInputs:
265
620
  continue
266
621
  return res
267
622
 
268
- def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None) -> Any:
623
+ def guess_dynamic_shape_object(
624
+ self, *objs: Any, auto: bool = False, msg: Optional[Callable] = None
625
+ ) -> Any:
269
626
  """Guesses the dynamic shapes for one argument."""
270
627
  if len(objs) == 0:
271
628
  return None
@@ -279,7 +636,7 @@ class ModelInputs:
279
636
  if isinstance(obj, (bool, int, float, str)):
280
637
  return None
281
638
  if isinstance(obj, (torch.Tensor, np.ndarray)):
282
- return self.guess_dynamic_dimensions(*objs)
639
+ return self.guess_dynamic_dimensions(*objs, auto=auto)
283
640
 
284
641
  if isinstance(obj, tuple):
285
642
  kl = set(len(o) for o in objs)
@@ -288,7 +645,9 @@ class ModelInputs:
288
645
  ), f"Unexpected variety of tuple lengths {kl}{msg() if msg else ''}"
289
646
  shapes: Any = []
290
647
  for i in range(kl.pop()):
291
- shapes.append(self.guess_dynamic_shape_object(*[o[i] for o in objs]))
648
+ shapes.append(
649
+ self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
650
+ )
292
651
  return tuple(shapes)
293
652
 
294
653
  if isinstance(obj, list):
@@ -298,7 +657,9 @@ class ModelInputs:
298
657
  ), f"Unexpected variety of list lengths {kl}{msg() if msg else ''}"
299
658
  shapes = []
300
659
  for i in range(kl.pop()):
301
- shapes.append(self.guess_dynamic_shape_object(*[o[i] for o in objs]))
660
+ shapes.append(
661
+ self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
662
+ )
302
663
  return shapes
303
664
 
304
665
  if isinstance(obj, dict):
@@ -308,9 +669,33 @@ class ModelInputs:
308
669
  ), f"Unexpected variety of dict lengths {kl}{msg() if msg else ''}"
309
670
  shapes = {}
310
671
  for i in obj:
311
- shapes[i] = self.guess_dynamic_shape_object(*[o[i] for o in objs])
672
+ shapes[i] = self.guess_dynamic_shape_object(
673
+ *[o[i] for o in objs], auto=auto, msg=msg
674
+ )
312
675
  return shapes
313
676
 
677
+ if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
678
+ kcl = set(o.__class__ for o in objs)
679
+ assert len(kcl) == 1, (
680
+ f"All instances of argument {i} are not of the same class but {kcl}, "
681
+ f"types should be the same."
682
+ )
683
+ col_args = [flatten_unflatten_for_dynamic_shapes(o) for o in objs]
684
+ kc = set(len(o) for o in col_args)
685
+ assert len(kc) == 1, (
686
+ f"All instances of type {kcl.pop()} are not serialized into the same number "
687
+ f"of arguments, it should be the same."
688
+ )
689
+ values = []
690
+ for i in range(kc.pop()):
691
+ values.append(
692
+ self.guess_dynamic_shape_object(
693
+ *[ca[i] for ca in col_args], auto=auto, msg=msg
694
+ )
695
+ )
696
+ return values
697
+
698
+ # In case DynamicCache is not registered.
314
699
  if obj.__class__.__name__ == "DynamicCache":
315
700
  kc = set(len(o.key_cache) for o in objs)
316
701
  assert (
@@ -323,34 +708,39 @@ class ModelInputs:
323
708
  key_cache = []
324
709
  for i in range(kc.pop()):
325
710
  key_cache.append(
326
- self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs])
711
+ self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs], auto=auto)
327
712
  )
328
713
  value_cache = []
329
714
  for i in range(vc.pop()):
330
715
  value_cache.append(
331
- self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs])
716
+ self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs], auto=auto)
332
717
  )
333
718
  return [key_cache, value_cache]
334
719
 
335
720
  raise NotImplementedError(
336
721
  f"Unable to build dynamic shapes for type {set_types.pop()}: "
337
- f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}"
722
+ f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}, "
723
+ f"this object needs serialization function to be registered."
338
724
  )
339
725
 
340
- def guess_dynamic_shapes(
341
- self,
342
- ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
726
+ def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
343
727
  """
344
728
  Guesses the dynamic shapes for that module from two execution.
345
729
  If there is only one execution, then that would be static dimensions.
730
+
731
+ :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
732
+ dimension if the number of inputs is one
346
733
  """
347
734
  if len(self.inputs) == 0:
348
735
  # No inputs, unable to guess.
349
736
  return (tuple(), {})
350
737
  if len(self.inputs) == 1:
351
738
  # No dynamic shapes.
352
- return tuple(self.guess_dynamic_shape_object(a) for a in self.inputs[0][0]), {
353
- k: self.guess_dynamic_shape_object(v) for k, v in self.inputs[0][1].items()
739
+ return tuple(
740
+ self.guess_dynamic_shape_object(a, auto=auto) for a in self.inputs[0][0]
741
+ ), {
742
+ k: self.guess_dynamic_shape_object(v, auto=auto)
743
+ for k, v in self.inputs[0][1].items()
354
744
  }
355
745
 
356
746
  # Otherwise.
@@ -365,7 +755,9 @@ class ModelInputs:
365
755
  for i in range(s1.pop()):
366
756
  objs = [_[0][i] for _ in self.inputs]
367
757
  args.append(
368
- self.guess_dynamic_shape_object(*objs, msg=lambda i=i: f" failing input {i}")
758
+ self.guess_dynamic_shape_object(
759
+ *objs, auto=auto, msg=lambda i=i: f" failing input {i}"
760
+ )
369
761
  )
370
762
  names = s2.pop()
371
763
  for name in names:
@@ -377,7 +769,7 @@ class ModelInputs:
377
769
 
378
770
  objs = [_[1][name] for _ in self.inputs]
379
771
  kwargs[name] = self.guess_dynamic_shape_object(
380
- *objs, msg=lambda name=name: f" failing input {name!r}"
772
+ *objs, auto=auto, msg=lambda name=name: f" failing input {name!r}"
381
773
  )
382
774
  return tuple(args), kwargs
383
775
 
@@ -386,7 +778,7 @@ class ModelInputs:
386
778
  args: Tuple[Any, ...],
387
779
  kwargs: Dict[str, Any],
388
780
  dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
389
- ) -> Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]:
781
+ ) -> Tuple[Tuple[Any, ...], Dict[str, Any], DYNAMIC_SHAPES]:
390
782
  """
391
783
  Uses the signatures to move positional arguments (args) to named arguments (kwargs)
392
784
  with the corresponding dynamic shapes.
@@ -434,3 +826,22 @@ class ModelInputs:
434
826
  f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
435
827
  )
436
828
  return args, kwargs, (tuple(), kw_dyn)
829
+
830
def validate_inputs_for_export(
    self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
) -> List[List[Union[int, str]]]:
    """
    Validates the inputs the class contains for the given dynamic shapes.
    If not specified, the dynamic_shapes are guessed.

    :param dynamic_shapes: dynamic shapes to validate, if None, they are
        guessed with ``guess_dynamic_shapes`` unless there is only one set
        of inputs, in which case nothing is checked
    :return: a list with one item per set of inputs, every item is the
        result of ``CoupleInputsDynamicShapes.invalid_dimensions_for_export``
        for that set of inputs
    """
    if dynamic_shapes is None:
        if len(self.inputs) == 1:
            # A single set of inputs: nothing is checked.
            return []
        dynamic_shapes = self.guess_dynamic_shapes()
    # The shapes given by the user must be used as well: the previous
    # implementation only defined the local variable when they were guessed.
    # NOTE(review): guess_dynamic_shapes returns a pair (args, kwargs) while
    # CoupleInputsDynamicShapes asserts a dict or a tuple aligned with the
    # inputs -- confirm the expected layout of dynamic_shapes here.
    return [
        CoupleInputsDynamicShapes(*i, dynamic_shapes).invalid_dimensions_for_export()
        for i in self.inputs
    ]