onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/export/dynamic_shapes.py
@@ -0,0 +1,1083 @@
+ import inspect
+ import itertools
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+ import numpy as np
+ import torch
+ from ..helpers import string_type
+ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
+
+ DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
+
+
+ def _flatten_dynamic_shapes(ds: Any) -> Any:
+     """Flattens the dynamic shapes."""
+     if isinstance(ds, list):
+         return _flat_list([_flatten_dynamic_shapes(t) for t in ds])
+     if isinstance(ds, tuple):
+         return tuple(_flat_list([_flatten_dynamic_shapes(t) for t in ds]))
+     if isinstance(ds, dict):
+         if all(isinstance(i, int) for i in ds):
+             # That's a dynamic shape
+             return ds
+         return _flat_list([_flatten_dynamic_shapes(t) for t in ds.values()])
+     raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
+
+
+ def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
+     res = []
+     for t in li:
+         if isinstance(t, dict):
+             res.append(t)
+         else:
+             res.extend(t)
+     return res
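Editorial aside (not part of the wheel's diff): the two private helpers above collapse a nested dynamic-shapes structure into a flat list of per-tensor shape dictionaries; dictionaries whose keys are all integers are treated as leaves. A minimal sketch, assuming the module path matches the wheel layout shown above:

    from onnx_diagnostic.export.dynamic_shapes import _flatten_dynamic_shapes

    # Nested tuples/dicts are flattened; integer-keyed dicts are kept as leaves.
    ds = {"A": {0: "batch"}, "B": ({0: "batch"}, {0: "batch", 1: "seq"})}
    print(_flatten_dynamic_shapes(ds))
    # [{0: 'batch'}, {0: 'batch'}, {0: 'batch', 1: 'seq'}]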
+
+
+ class CoupleInputsDynamicShapes:
+     """
+     Pair inputs / dynamic shapes.
+
+     :param args: positional arguments
+     :param kwargs: named arguments
+     :param dynamic_shapes: dynamic shapes
+     :param args_names: if both args and kwargs are not empty, then
+         dynamic shapes must be a dictionary and positional arguments must
+         be moved to the named arguments. Argument names or a module must
+         be given in that case.
+     """
+
+     def __init__(
+         self,
+         args: Tuple[Any, ...],
+         kwargs: Dict[str, Any],
+         dynamic_shapes: DYNAMIC_SHAPES,
+         args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
+     ):
+         self.args = args
+         self.kwargs = kwargs
+         self.dynamic_shapes = dynamic_shapes
+         self.args_names = args_names
+         if not self.kwargs and isinstance(self.dynamic_shapes, dict):
+             # This assumes the dictionary for the dynamic shapes is ordered
+             # the same way the args are. The input names are not known.
+             assert len(self.dynamic_shapes) == len(self.args), (
+                 f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
+                 f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
+             )
+             self.dynamic_shapes = tuple(self.dynamic_shapes.values())
+
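A sketch of the normalization done by the constructor (editorial example, not part of the package diff): with positional inputs only, a dictionary of dynamic shapes is reinterpreted as an ordered tuple.

    import torch
    from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

    # The key "x" is only a label; the dict is assumed ordered like the args.
    c = CoupleInputsDynamicShapes((torch.rand((3, 4)),), {}, {"x": {0: "batch"}})
    print(c.dynamic_shapes)
    # ({0: 'batch'},)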
+     def __str__(self) -> str:
+         return "\n".join(
+             [
+                 f"{self.__class__.__name__}(",
+                 f" args={string_type(self.args, with_shape=True)},"
+                 f" kwargs={string_type(self.kwargs, with_shape=True)},"
+                 f" dynamic_shapes={string_type(self.dynamic_shapes, with_shape=True)},"
+                 f")",
+             ]
+         )
+
+     def replace_string_by(self, value: Any = None):
+         """
+         Replaces strings by ``torch.export.Dim.DYNAMIC``
+         (default) or any other value specified by *value*.
+
+         Example:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
+
+             T3x1 = torch.rand((3, 1))
+             T3x4 = torch.rand((3, 4))
+             ds_batch = {0: "batch"}
+             ds_batch_seq = {0: "batch", 1: "seq"}
+             kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
+             ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
+             print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())
+         """
+         return self._generic_walker(
+             lambda inputs, ds, value=value: self._replace_string_dim_tensor(
+                 inputs, ds, value=value
+             ),
+             flatten_unflatten=True,
+         )
+
+     @classmethod
+     def _replace_string_dim_tensor(cls, inputs, ds, value=None):
+         assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
+         assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
+             f"Unexpected types, inputs is a Tensor but ds is {ds}, "
+             f"a dictionary is expected to specify a dimension"
+         )
+         if value is None:
+             value = torch.export.Dim.DYNAMIC
+         new_ds = ds.copy()
+         for i, v in ds.items():
+             if isinstance(v, str):
+                 new_ds[i] = value
+         return new_ds
+
+     def replace_by_string(self):
+         """
+         Replaces dimensions by strings.
+
+         Example:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
+
+             Dim = torch.export.Dim
+             T3x1 = torch.rand((3, 1))
+             T3x4 = torch.rand((3, 4))
+             ds_batch = {0: Dim("batch")}
+             ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
+             kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
+             ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
+             print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
+         """
+         unique = set()
+         return self._generic_walker(
+             lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
+                 inputs, ds, unique=unique
+             ),
+             flatten_unflatten=True,
+         )
+
+     @classmethod
+     def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
+         assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
+         assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
+             f"Unexpected types, inputs is a Tensor but ds is {ds}, "
+             f"a dictionary is expected to specify a dimension"
+         )
+         new_ds = ds.copy()
+         for i, v in ds.items():
+             if isinstance(v, str):
+                 unique.add(v)
+                 new_ds[i] = v
+             elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
+                 name = f"Dim{len(unique)}"
+                 new_ds[i] = name
+                 unique.add(name)
+             else:
+                 name = v.__name__
+                 unique.add(name)
+                 new_ds[i] = name
+         return new_ds
+
+     def invalid_dimensions_for_export(self):
+         """
+         Tells if the inputs are valid based on the dynamic shapes definition.
+         The method assumes that all custom classes can be serialized.
+         If patches were applied for the export, they should be enabled while
+         calling this method if the inputs contain such classes.
+
+         The function checks that a dynamic dimension does not receive a value
+         of 0 or 1. It returns the unexpected values in the same structure as
+         the given dynamic shapes.
+
+         Example:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
+
+             T3x1 = torch.rand((3, 1))
+             T3x4 = torch.rand((3, 4))
+             ds_batch = {0: "batch"}
+             ds_batch_seq = {0: "batch", 1: "seq"}
+             kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
+             ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
+             print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
+
+         When the inputs are valid, it returns ``None``:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
+
+             T3x2 = torch.rand((3, 2))
+             T3x4 = torch.rand((3, 4))
+             ds_batch = {0: "batch"}
+             ds_batch_seq = {0: "batch", 1: "seq"}
+             kwargs = {"A": T3x4, "B": (T3x2, T3x2)}
+             ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
+             print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
+         """
+         return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)
+
+     @classmethod
+     def _valid_shapes_tensor(cls, inputs, ds):
+         assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
+         assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
+             f"Unexpected types, inputs is a Tensor but ds is {ds}, "
+             f"a dictionary is expected to specify a dimension"
+         )
+         issues = {}
+         for i, d in enumerate(inputs.shape):
+             if i in ds and not isinstance(ds[i], int):
+                 # dynamic then
+                 if isinstance(d, int) and d in {0, 1}:
+                     # export issues for sure
+                     issues[i] = f"d=[{d}]"
+         return issues if issues else None
+
+     def _generic_walker(
+         self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
+     ):
+         """
+         Generic walker going through inputs and dynamic_shapes in parallel.
+         The function returns a result with the same structure as the dynamic shapes.
+         """
+         if not self.args:
+             assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
+                 f"Type mismatch, args={string_type(self.args)}, "
+                 f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
+                 f"{string_type(self.dynamic_shapes)} should have the same type."
+             )
+             res = self._generic_walker_step(
+                 processor,
+                 self.kwargs,
+                 self.dynamic_shapes,
+                 flatten_unflatten=flatten_unflatten,
+             )
+             return (tuple(), res) if args_kwargs else res
+
+         if not self.kwargs:
+             assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
+                 f"Type mismatch, args={string_type(self.args)} and "
+                 f"dynamic_shapes={self.dynamic_shapes} should have the same type."
+             )
+             res = self._generic_walker_step(
+                 processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
+             )
+             return (res, {}) if args_kwargs else res
+
+         assert isinstance(self.dynamic_shapes, dict), (
+             f"Both positional and named arguments (args and kwargs) are filled. "
+             f"dynamic shapes must be a dictionary, not {type(self.dynamic_shapes)}"
+         )
+         if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
+             self.dynamic_shapes
+         ):
+             # No dynamic shapes for the positional arguments.
+             return self._generic_walker_step(
+                 processor,
+                 self.kwargs,
+                 self.dynamic_shapes,
+                 flatten_unflatten=flatten_unflatten,
+             )
+
+         if isinstance(self.args_names, list):
+             if not set(self.args_names) & set(self.dynamic_shapes):
+                 # No dynamic shapes for the positional arguments.
+                 return self._generic_walker_step(
+                     processor,
+                     self.kwargs,
+                     self.dynamic_shapes,
+                     flatten_unflatten=flatten_unflatten,
+                 )
+
+             assert self.args_names, (
+                 "args and kwargs are filled, then args_names must be specified in "
+                 "the constructor to move positional arguments to named arguments."
+             )
+             assert len(self.args) <= len(self.args_names), (
+                 f"There are {len(self.args)} positional arguments "
+                 f"but only {len(self.args_names)} names. "
+                 f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
+             )
+             kwargs = dict(zip(self.args_names, self.args))
+             kwargs.update(self.kwargs)
+             res = self._generic_walker_step(
+                 processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
+             )
+             if args_kwargs:
+                 pgs = [None for _ in range(len(self.args))]
+                 kws = {}
+                 for k, v in res.items():
+                     if k not in self.kwargs:
+                         pgs[self.args_names.index(k)] = v
+                     else:
+                         kws[k] = v
+                 return pgs, kws
+             return res
+
+         raise NotImplementedError(
+             f"Not yet implemented when args is filled, "
+             f"kwargs as well but args_names is {type(self.args_names)}"
+         )
+
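A sketch of how the walker dispatches (editorial example; it relies on the private `_generic_walker` API shown above, so it is illustrative only): the processor receives (tensor, shape-dict) pairs and its results are reassembled into the structure of the dynamic shapes.

    import torch
    from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

    c = CoupleInputsDynamicShapes((), {"A": torch.rand((3, 4))}, {"A": {0: "batch"}})
    # The lambda is the processor: here it maps every dynamic axis to its current size.
    print(c._generic_walker(lambda t, ds: {i: t.shape[i] for i in ds}))
    # {'A': {0: 3}}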
+     @classmethod
+     def _generic_walker_step(
+         cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
+     ):
+         if isinstance(inputs, torch.Tensor):
+             return processor(inputs, ds)
+         if isinstance(inputs, (int, float, str)):
+             return None
+         if type(inputs) in (tuple, list, dict):
+             # Type must be strict, some custom classes can inherit from those.
+             assert type(inputs) is type(ds), (
+                 f"Input type and dynamic shape type must match but "
+                 f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
+                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
+             )
+             assert len(ds) == len(inputs), (
+                 f"Length mismatch between inputs {len(inputs)} "
+                 f"and ds={len(ds)}\n"
+                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
+             )
+             if type(inputs) in (tuple, list):
+                 value = []
+                 for i, d in zip(inputs, ds):
+                     value.append(
+                         cls._generic_walker_step(
+                             processor, i, d, flatten_unflatten=flatten_unflatten
+                         )
+                     )
+                 return (
+                     (value if isinstance(ds, list) else tuple(value))
+                     if any(v is not None for v in value)
+                     else None
+                 )
+             assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+             assert set(inputs) == set(ds), (
+                 f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
+                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
+             )
+             dvalue = {}
+             for k, v in inputs.items():
+                 t = cls._generic_walker_step(
+                     processor, v, ds[k], flatten_unflatten=flatten_unflatten
+                 )
+                 if t is not None:
+                     dvalue[k] = t
+             return dvalue if dvalue else None
+
+         # A custom class.
+         assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
+             f"Class {inputs.__class__.__name__!r} was not registered using "
+             f"torch.utils._pytree.register_pytree_node, it is not possible to "
+             f"map this class with the given dynamic shapes."
+         )
+         if flatten_unflatten:
+             flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
+             res = cls._generic_walker_step(
+                 processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
+             )
+             # Should we restore the original class?
+             return res
+         flat, spec = torch.utils._pytree.tree_flatten(inputs)
+         if all(isinstance(t, torch.Tensor) for t in flat):
+             # We need to flatten dynamic shapes as well.
+             ds = _flatten_dynamic_shapes(ds)
+         res = cls._generic_walker_step(
+             processor, flat, ds, flatten_unflatten=flatten_unflatten
+         )
+         # Then we restore the original class.
+         return torch.utils._pytree.tree_unflatten(res, spec)
+
+     class ChangeDimensionProcessor:
+         def __init__(self, desired_values, only_desired):
+             self.mapping = desired_values or {}
+             self.only_desired = only_desired
+
+         def _build_new_shape(
+             self, shape: Tuple[int, ...], ds: Dict[int, Any]
+         ) -> Tuple[int, ...]:
+             new_shape = list(shape)
+             for i in range(len(shape)):
+                 if i in ds:
+                     if isinstance(ds[i], str):
+                         d = ds[i]
+                     elif isinstance(
+                         ds[i],
+                         (
+                             torch.export.dynamic_shapes._DerivedDim,
+                             torch.export.dynamic_shapes._Dim,
+                         ),
+                     ):
+                         d = ds[i].__name__
+                     elif isinstance(ds[i], int):
+                         # A static dimension, nothing to resize
+                         # (also guards d against being undefined below).
+                         continue
+                     else:
+                         raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
+                     if d in self.mapping:
+                         new_dim = self.mapping[d]
+                     elif not self.only_desired:
+                         new_dim = shape[i] + 1
+                         self.mapping[d] = new_dim
+                     else:
+                         new_dim = shape[i]
+                     new_shape[i] = new_dim
+             return tuple(new_shape)
+
+         def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
+             rank = len(tensor.shape)
+             for i in range(len(tensor.shape)):
+                 d0 = tensor.shape[i]
+                 d1 = new_shape[i]
+                 if d0 == d1:
+                     continue
+                 alt_shape = list(tensor.shape)
+                 alt_shape[i] = d1
+                 new_tensor = torch.zeros(
+                     tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
+                 )
+                 mind = min(d0, d1)
+                 indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
+                 indices[i] = slice(0, mind)
+                 ind = tuple(indices)
+                 new_tensor[ind] = tensor[ind]
+                 if d1 > mind:
+                     for k in range(d1 - mind):
+                         indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
+                         indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
+                         indices1[i] = mind + k
+                         indices0[i] = k % mind
+                         new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
+                 tensor = new_tensor
+             return tensor
+
+         def __call__(self, inputs, ds):
+             assert isinstance(
+                 inputs, torch.Tensor
+             ), f"unexpected type for inputs {type(inputs)}"
+             assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
+                 f"Unexpected types, inputs is a Tensor but ds is {ds}, "
+                 f"a dictionary is expected to specify a dimension"
+             )
+             new_shape = self._build_new_shape(inputs.shape, ds)
+             return self._build_new_tensor(inputs, new_shape)
+
+     def change_dynamic_dimensions(
+         self,
+         desired_values: Optional[Dict[str, int]] = None,
+         args_kwargs: bool = False,
+         only_desired: bool = False,
+     ):
+         """
+         A model exported with dynamic shapes is not necessarily dynamic
+         just because the user specified dynamic shapes. The algorithm
+         may discover that a dimension cannot be dynamic and then continue
+         the export assuming it is static. That may lead to a wrong
+         model. This function produces a new set of inputs with different values
+         for the dynamic dimensions than the first ones, assuming they were used
+         to export the model.
+
+         :param desired_values: fixes named dimensions to the given values
+         :param args_kwargs: return both args, kwargs even if empty
+         :param only_desired: if True, only change the dimensions specified in
+             ``desired_values``
+         :return: new inputs
+
+         Example:
+
+         .. runpython::
+             :showcode:
+
+             import torch
+             from onnx_diagnostic.helpers import string_type
+             from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
+
+             T3x15 = torch.rand((3, 15))
+             T3x20 = torch.rand((3, 20))
+             T3x4 = torch.rand((3, 4))
+             ds_batch = {0: "batch"}
+             ds_batch_seq = {0: "batch", 1: "seq"}
+             kwargs = {"A": T3x4, "B": (T3x15, T3x20)}
+             ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
+             new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
+             print("before:", string_type(kwargs, with_shape=True))
+             print("-after:", string_type(new_kwargs, with_shape=True))
+         """
+         return self._generic_walker(
+             self.ChangeDimensionProcessor(desired_values, only_desired=only_desired),
+             args_kwargs=args_kwargs,
+         )
+
+
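The runpython example above lets the processor pick new dimension values on its own. A sketch of the `desired_values`/`only_desired` variant (editorial example, not part of the diff):

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

    kwargs = {"A": torch.rand((3, 4)), "B": (torch.rand((3, 15)), torch.rand((3, 20)))}
    ds = {"A": {0: "batch"}, "B": ({0: "batch"}, {0: "batch", 1: "seq"})}
    # Pin "batch" to 7; only_desired=True leaves the "seq" dimension untouched.
    new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions(
        desired_values={"batch": 7}, only_desired=True
    )
    print(string_type(new_kwargs, with_shape=True))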
+ class ModelInputs:
+     """
+     Wraps a model and one or several sets of valid inputs.
+     Based on that information, the class is able to infer the dynamic shapes
+     for :func:`torch.export.export`.
+
+     :param model: model to export
+     :param inputs: list of valid sets of inputs
+     :param level: if this module is a submodule, it is the level of submodule
+     :param method_name: by default, the forward method is processed but it
+         could be another one
+     :param name: a name, mostly for debugging purposes
+
+     Examples:
+
+     **args**
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.export import ModelInputs
+
+
+         class Model(torch.nn.Module):
+             def forward(self, x, y):
+                 return x + y
+
+
+         model = Model()
+         x = torch.randn((5, 6))
+         y = torch.randn((1, 6))
+         model(x, y)  # to check it works
+
+         inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
+         mi = ModelInputs(Model(), inputs)
+         ds = mi.guess_dynamic_shapes()
+         pprint.pprint(ds)
+
+     **kwargs**
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.export import ModelInputs
+
+
+         class Model(torch.nn.Module):
+             def forward(self, x, y):
+                 return x + y
+
+
+         model = Model()
+         x = torch.randn((5, 6))
+         y = torch.randn((1, 6))
+         model(x=x, y=y)  # to check it works
+
+         inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
+         mi = ModelInputs(Model(), inputs)
+         ds = mi.guess_dynamic_shapes()
+         pprint.pprint(ds)
+
+     **args and kwargs**
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.export import ModelInputs
+
+
+         class Model(torch.nn.Module):
+             def forward(self, x, y):
+                 return x + y
+
+
+         model = Model()
+         x = torch.randn((5, 6))
+         y = torch.randn((1, 6))
+         model(x, y=y)  # to check it works
+
+         inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
+         mi = ModelInputs(Model(), inputs)
+         ds = mi.guess_dynamic_shapes()
+         pprint.pprint(ds)
+
+     :func:`torch.export.export` does not like dynamic shapes defined both as args and kwargs.
+     kwargs must be used. ``move_to_kwargs`` modifies the inputs and the dynamic shapes
+     to make the model and the given inputs exportable.
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.export import ModelInputs
+         from onnx_diagnostic.helpers import string_type
+
+
+         class Model(torch.nn.Module):
+             def forward(self, x, y):
+                 return x + y
+
+
+         model = Model()
+         x = torch.randn((5, 6))
+         y = torch.randn((1, 6))
+         model(x, y=y)  # to check it works
+
+         inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
+         mi = ModelInputs(Model(), inputs)
+         ds = mi.guess_dynamic_shapes()
+
+         a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
+         print("moved args:", string_type(a, with_shape=True))
+         print("moved kwargs:", string_type(kw, with_shape=True))
+         print("dynamic shapes:")
+         pprint.pprint(nds)
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         inputs: Union[
+             List[Tuple[Any, ...]],
+             List[Dict[str, Any]],
+             List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
+         ],
+         level: int = 0,
+         method_name: str = "forward",
+         name: str = "main",
+     ):
+         assert (
+             model is None or isinstance(model, torch.nn.Module) or inspect.ismodule(model)
+         ), (
+             f"unexpected type for model={type(model)}, "
+             f"it must be a torch.nn.Module or None"
+         )
+         assert name, (
+             f"name={name!r} cannot be empty, this string is used to "
+             f"display meaningful error messages"
+         )
+         self.name = name
+         self.model = model
+         self.level = level
+         self.method_name = method_name
+         self.forward = getattr(model, method_name) if model is not None else None
+         self.signature = inspect.signature(self.forward) if self.forward else None
+
+         # information about the signature
+         self.forward_parameter_names = (
+             set(
+                 p.name
+                 for p in self.signature.parameters.values()
+                 if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
+             )
+             if self.signature
+             else None
+         )
+         self.forward_ordered_parameter_names = (
+             list(self.signature.parameters) if self.signature else None
+         )
+         self.forward_positioned_parameter_names = (
+             [
+                 p.name
+                 for p in self.signature.parameters.values()
+                 if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+             ]
+             if self.signature
+             else None
+         )
+         names = (
+             [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL]
+             if self.signature
+             else None
+         )
+         self.forward_args = names[0] if names else None
+         names = (
+             [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
+             if self.signature
+             else None
+         )
+         self.forward_kwargs = names[0] if names else None
+         self.forward_custom_op_schema = None
+         self.forward_need_serialization = False
+         self.forward_fill_kwargs = bool(self.forward_kwargs)
+         assert not isinstance(
+             model, (torch.nn.ModuleList, torch.nn.ModuleDict)
+         ), f"ModuleList or ModuleDict should not be traced: {type(model)}"
+
+         # process the inputs
+         self.inputs = self.process_inputs(inputs)
+
+     def process_inputs(
+         self,
+         inputs: Union[
+             List[Tuple[Any, ...]],
+             List[Dict[str, Any]],
+             List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
+         ],
+     ) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
+         """
+         Transforms a list of valid inputs, given as a list of args tuples,
+         a list of kwargs dictionaries, or a list of both,
+         into a list of (args, kwargs) pairs.
+         """
+         if not isinstance(inputs, list):
+             raise ValueError(
+                 f"inputs should be specified as a list of sets of "
+                 f"inputs but type(inputs) is {type(inputs)}"
+             )
+         new_inputs = []
+         for i, inp in enumerate(inputs):
+             if (
+                 isinstance(inp, tuple)
+                 and len(inp) == 2
+                 and isinstance(inp[0], tuple)
+                 and isinstance(inp[1], dict)
+             ):
+                 new_inputs.append(inp)
+                 continue
+             if isinstance(inp, tuple):
+                 new_inputs.append((inp, {}))
+                 continue
+             if isinstance(inp, dict):
+                 new_inputs.append(((), inp))
+                 continue
+             raise ValueError(f"Unable to interpret inputs {i}: {string_type(inp)}")
+         return new_inputs
+
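A sketch of the three accepted input formats being normalized (editorial example; `process_inputs` does not need a model, so `None` is passed):

    import torch
    from onnx_diagnostic.export import ModelInputs

    x, y = torch.randn((2, 3)), torch.randn((2, 3))
    # An args tuple, a kwargs dict, and an (args, kwargs) pair all normalize
    # to (args, kwargs) pairs.
    mi = ModelInputs(None, [(x, y), dict(x=x), ((x,), dict(y=y))])
    for args, kwargs in mi.inputs:
        print(len(args), sorted(kwargs))
    # 2 []
    # 0 ['x']
    # 1 ['y']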
+     @property
+     def true_model_name(self) -> str:
+         "Returns class name or module name."
+         assert self.model is not None, "model was None when the class was initialized."
+         return (
+             self.model.__class__.__name__
+             if isinstance(self.model, torch.nn.Module)
+             else self.model.__name__
+         )
+
+     @property
+     def full_name(self) -> str:
+         "Returns a name and class name."
+         if self.method_name == "forward":
+             return f"{self.name}:{self.true_model_name}"
+         return f"{self.name}:{self.true_model_name}.{self.method_name}"
+
+     @property
+     def module_name_type(self):
+         "Returns name and module type."
+         if self.method_name == "forward":
+             return f"type({self.name})={self.true_model_name}"
+         return f"type({self.name})={self.true_model_name}.{self.method_name}"
+
+     def guess_dynamic_dimensions(
+         self, *tensors, auto: Union[bool, str] = False
+     ) -> Optional[Dict[int, Any]]:
+         """
+         Infers the dynamic dimensions from multiple shapes.
+         If *auto* is True, it returns ``torch.export.Dim.AUTO`` for every dimension
+         which cannot be guessed. A dimension can be guessed when at least two
+         tensors are given; with a single tensor, it cannot. ``auto`` can be a string
+         to produce strings.
+         """
+         if len(tensors) == 1:
+             if isinstance(tensors[0], (int, float)):
+                 return None
+             assert isinstance(tensors[0], torch.Tensor), (
+                 f"Unexpected type for tensors {string_type(tensors, with_shape=True)}, "
+                 f"Only tensors are allowed."
+             )
+             return (
+                 {i: torch.export.Dim.AUTO for i in range(len(tensors[0].shape))}  # noqa: C420
+                 if auto and not isinstance(auto, str)
+                 else {}
+             )
+         shapes = [t.shape for t in tensors]
+         set_length = set(len(s) for s in shapes)
+         assert len(set_length) == 1, (
+             f"Shapes can be different but not ranks, possible ranks={set_length}, "
+             f"shapes={shapes} for module {self.name!r}, "
+             f"class={self.true_model_name!r}"
+         )
+         dynamic: Any = (
+             auto
+             if isinstance(auto, str)
+             else (torch.export.Dim.AUTO if auto else torch.export.Dim.DYNAMIC)
+         )
+         rk = set_length.pop()
+         res = {}
+         for i in range(rk):
+             set_dim = set(s[i] for s in shapes)
+             if len(set_dim) > 1:
+                 res[i] = dynamic if not isinstance(dynamic, str) else f"{dynamic}{i}"
+                 continue
+             if set_dim == {0}:
+                 # It is unexpected to find a null dimension. Let's replace it by a dynamic one.
+                 res[i] = dynamic if not isinstance(dynamic, str) else f"{dynamic}{i}"
+                 continue
+         return res
+
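A sketch of the per-dimension rule (editorial example): a dimension whose value differs across executions becomes dynamic, a constant one stays static.

    import torch
    from onnx_diagnostic.export import ModelInputs

    mi = ModelInputs(None, [])  # no model is needed to guess dimensions
    # dim 0 differs (3 vs 5) -> dynamic; dim 1 is constant (4) -> omitted.
    print(mi.guess_dynamic_dimensions(torch.rand((3, 4)), torch.rand((5, 4))))
    # {0: <torch.export.Dim.DYNAMIC>}  (the exact repr varies across torch versions)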
+     def guess_dynamic_shape_object(
+         self, *objs: Any, auto: Union[bool, str] = False, msg: Optional[Callable] = None
+     ) -> Any:
+         """Guesses the dynamic shapes for one argument."""
+         if len(objs) == 0:
+             return None
+         set_types = set(type(o) for o in objs)
+         assert (
+             len(set_types) == 1
+         ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
+         obj = objs[0]
+         if obj is None:
+             return None
+         if isinstance(obj, (bool, int, float, str)):
+             return None
+         if isinstance(obj, (torch.Tensor, np.ndarray)):
+             return self.guess_dynamic_dimensions(*objs, auto=auto)
+
+         if isinstance(obj, tuple):
+             kl = set(len(o) for o in objs)
+             assert (
+                 len(kl) == 1
+             ), f"Unexpected variety of tuple lengths {kl}{msg() if msg else ''}"
+             shapes: Any = []
+             for i in range(kl.pop()):
+                 shapes.append(
+                     self.guess_dynamic_shape_object(
+                         *[o[i] for o in objs],
+                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}t",
+                         msg=msg,
+                     )
+                 )
+             return tuple(shapes)
+
+         if isinstance(obj, list):
+             kl = set(len(o) for o in objs)
+             assert (
+                 len(kl) == 1
+             ), f"Unexpected variety of list lengths {kl}{msg() if msg else ''}"
+             shapes = []
+             for i in range(kl.pop()):
+                 shapes.append(
+                     self.guess_dynamic_shape_object(
+                         *[o[i] for o in objs],
+                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}l",
+                         msg=msg,
+                     )
+                 )
+             return shapes
+
+         if isinstance(obj, dict):
+             kl = set(len(o) for o in objs)
+             assert (
+                 len(kl) == 1
+             ), f"Unexpected variety of dict lengths {kl}{msg() if msg else ''}"
+             shapes = {}
+             for i in obj:
+                 shapes[i] = self.guess_dynamic_shape_object(
+                     *[o[i] for o in objs],
+                     auto=auto if isinstance(auto, bool) else f"{auto}_{i}d",
+                     msg=msg,
+                 )
+             return shapes
+
+         if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
+             kcl = set(o.__class__ for o in objs)
+             assert len(kcl) == 1, (
+                 f"All instances of this argument are not of the same class but {kcl}, "
+                 f"types should be the same."
+             )
+             col_args = [flatten_unflatten_for_dynamic_shapes(o) for o in objs]
+             kc = set(len(o) for o in col_args)
+             assert len(kc) == 1, (
+                 f"All instances of type {kcl.pop()} are not serialized into the same number "
+                 f"of arguments, it should be the same."
+             )
+             values = []
+             for i in range(kc.pop()):
+                 values.append(
+                     self.guess_dynamic_shape_object(
+                         *[ca[i] for ca in col_args],
+                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}o",
+                         msg=msg,
+                     )
+                 )
+             return values
+
+         # In case DynamicCache is not registered.
+         if obj.__class__.__name__ == "DynamicCache":
+             if hasattr(obj, "layers"):
+                 kc = set(len(o.layers) for o in objs)
+                 assert (
+                     len(kc) == 1
+                 ), f"All 'layers' attributes should have the same length but found {kc}"
+                 vc = kc.copy()
+             else:
+                 kc = set(len(o.key_cache) for o in objs)
+                 assert (
+                     len(kc) == 1
+                 ), f"All 'key_cache' attributes should have the same length but found {kc}"
+                 vc = set(len(o.value_cache) for o in objs)
+                 assert (
+                     len(vc) == 1
+                 ), f"All 'value_cache' attributes should have the same length but found {vc}"
+
+             key_cache = []
+             for i in range(kc.pop()):
+                 key_cache.append(
+                     self.guess_dynamic_dimensions(
+                         *[
+                             o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
+                             for o in objs
+                         ],
+                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
+                     )
+                 )
+             value_cache = []
+             for i in range(vc.pop()):
+                 value_cache.append(
+                     self.guess_dynamic_dimensions(
+                         *[
+                             o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
+                             for o in objs
+                         ],
+                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
+                     )
+                 )
+             return list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
+
+         raise NotImplementedError(
+             f"Unable to build dynamic shapes for type {set_types.pop()}: "
+             f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}, "
+             f"this object needs a serialization function to be registered."
+         )
+
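Nested containers are walked recursively, so the guessed shapes mirror the structure of the inputs (editorial sketch):

    import torch
    from onnx_diagnostic.export import ModelInputs

    mi = ModelInputs(None, [])
    # Two executions of the same nested structure; only tensor leaves get shape dicts.
    ds = mi.guess_dynamic_shape_object(
        {"x": [torch.rand((2, 3))]}, {"x": [torch.rand((4, 3))]}
    )
    print(ds)  # {'x': [{0: <dynamic>}]} -- dim 0 differs, dim 1 is static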
+     def guess_dynamic_shapes(self, auto: Union[bool, str] = False) -> DYNAMIC_SHAPES:
+         """
+         Guesses the dynamic shapes for that module from at least two executions.
+         If there is only one execution, all dimensions are considered static.
+
+         :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
+             dimension if the number of inputs is one,
+             if ``auto`` is a string, it uses strings
+         :return: guessed dynamic shapes
+
+         See example :ref:`l-guess-dynamic-shapes-example`.
+         """
+         if len(self.inputs) == 0:
+             # No inputs, unable to guess.
+             return (tuple(), {})
+         if len(self.inputs) == 1:
+             # No dynamic shapes.
+             return tuple(
+                 self.guess_dynamic_shape_object(a, auto=auto) for a in self.inputs[0][0]
+             ), {
+                 k: self.guess_dynamic_shape_object(v, auto=auto)
+                 for k, v in self.inputs[0][1].items()
+             }
+
+         # Otherwise.
+         s1 = set(len(i[0]) for i in self.inputs)
+         assert (
+             len(s1) == 1
+         ), f"Different numbers of positional arguments {s1} for {self.full_name}"
+         s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
+         assert len(s2) == 1, f"Different named arguments {s2} for {self.full_name}"
+         args = []
+         kwargs = {}
+         for i in range(s1.pop()):
+             objs = [_[0][i] for _ in self.inputs]
+             args.append(
+                 self.guess_dynamic_shape_object(
+                     *objs,
+                     auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
+                     msg=lambda i=i: f" failing input {i}",
+                 )
+             )
+         names = s2.pop()
+         for i, name in enumerate(names):
+             assert name not in {"_diag", "verbose"}, (
+                 f"{self.full_name}: unexpected parameter {name!r}, names={names}"
+                 f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
+                 f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
+             )
+
+             objs = [_[1][name] for _ in self.inputs]
+             kwargs[name] = self.guess_dynamic_shape_object(
+                 *objs,
+                 auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
+                 msg=lambda name=name: f" failing input {name!r}",
+             )
+         return tuple(args), kwargs
+
+     def move_to_kwargs(
+         self,
+         args: Tuple[Any, ...],
+         kwargs: Dict[str, Any],
+         dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
+     ) -> Tuple[Tuple[Any, ...], Dict[str, Any], DYNAMIC_SHAPES]:
+         """
+         Uses the signature to move positional arguments (args) to named arguments (kwargs)
+         with the corresponding dynamic shapes.
+         *kwargs* and *dynamic_shapes* are modified inplace.
+         """
+         assert (
+             self.signature is not None
+             and self.forward_parameter_names is not None
+             and self.forward_ordered_parameter_names is not None
+         ), (
+             "model was None when the class was initialized, "
+             "cannot move args to kwargs without the signature."
+         )
+         sig = self.signature
+         arg_dyn, kw_dyn = dynamic_shapes
+         for i, p in enumerate(sig.parameters):
+             if i >= len(arg_dyn):
+                 break
+             kw_dyn[p] = arg_dyn[i]
+         if self.forward_kwargs:
+             kdw = {}
+             for k, v in kw_dyn.items():
+                 if k not in self.forward_parameter_names:
+                     kdw[k] = v
+             if kdw:
+                 for k in kdw:
+                     del kw_dyn[k]
+                 kw_dyn[self.forward_kwargs] = kdw
+
+         # Let's reorder as it seems to matter later
+         # in the shape inference algorithm.
+         _kwargs = kwargs
+         kwargs = {}
+         _kw_dyn = kw_dyn
+         kw_dyn = {}
+         for name in self.forward_ordered_parameter_names:
+             if name in _kwargs:
+                 kwargs[name] = _kwargs[name]
+             if name in _kw_dyn:
+                 kw_dyn[name] = _kw_dyn[name]
+         for k in _kwargs:
+             if k not in kwargs:
+                 # Then it is part of **kwargs.
+                 kwargs[k] = _kwargs[k]
+         assert len(kw_dyn) == len(_kw_dyn), (
+             f"{self.full_name}: unexpected mismatch between _kw_dyn={set(_kw_dyn)} "
+             f"and kw_dyn={set(kw_dyn)}, "
+             f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
+         )
+         assert len(kwargs) == len(_kwargs), (
+             f"{self.full_name}: unexpected mismatch between _kwargs={set(_kwargs)} "
+             f"and kwargs={set(kwargs)}, "
+             f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
+         )
+         return args, kwargs, (tuple(), kw_dyn)
+
+     def validate_inputs_for_export(
+         self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
+     ) -> List[List[Union[int, str]]]:
+         """
+         Validates the inputs the class contains for the given dynamic shapes.
+         If not specified, the dynamic shapes are guessed.
+
+         :param dynamic_shapes: dynamic shapes to validate
+         :return: a list of lists, every list contains the path to an invalid dimension
+         """
+         if dynamic_shapes is None:
+             if len(self.inputs) == 1:
+                 return []
+             dyn_shapes = self.guess_dynamic_shapes()
+             return [
+                 CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_dimensions_for_export()
+                 for i in self.inputs
+             ]