onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/export/shape_helper.py
@@ -0,0 +1,296 @@
+ from typing import Any, Dict, List, Set, Optional, Tuple, Union
+ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
+ from .dynamic_shapes import ModelInputs
+
+
+ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
+     """
+     Returns the dynamic shapes for the given inputs.
+     All dimensions are considered as dynamic.
+     ``dim_prefix`` can be a string (the function uses it as a prefix),
+     or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+     Depending on the version of :epkg:`transformers` (>= 4.51, < 4.55), the
+     serialization functions of class ``DynamicCache`` are registered automatically or not.
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
+         from onnx_diagnostic.torch_export_patches import torch_export_patches
+
+         bsize, nheads, slen, dim = 2, 1, 30, 96
+         inputs = dict(
+             input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
+             attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+             position_ids=torch.arange(3, dtype=torch.int64),
+             past_key_values=make_dynamic_cache(
+                 [(torch.randn(bsize, nheads, slen, dim),
+                   torch.randn(bsize, nheads, slen, dim))]
+             ),
+         )
+         with torch_export_patches(patch_transformers=True):
+             ds = all_dynamic_shapes_from_inputs(inputs)
+             pprint.pprint(ds)
+
+     For this function to work, patches must be enabled if :epkg:`transformers`
+     does not implement the serialization functions.
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import (
+             make_dynamic_cache,
+             make_encoder_decoder_cache,
+             make_mamba_cache,
+             make_sliding_window_cache,
+             make_static_cache,
+         )
+         from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
+         from onnx_diagnostic.torch_export_patches import torch_export_patches
+
+         caches = [
+             make_dynamic_cache(
+                 [
+                     (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                     (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                     (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                 ]
+             ),
+             make_encoder_decoder_cache(
+                 make_dynamic_cache(
+                     [
+                         (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                         (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                         (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                     ]
+                 ),
+                 make_dynamic_cache(
+                     [
+                         (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                         (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                         (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
+                     ]
+                 ),
+             ),
+             make_sliding_window_cache(
+                 [
+                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                 ]
+             ),
+             make_static_cache(
+                 [
+                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                     (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                 ],
+                 max_cache_len=15,
+             ),
+             make_mamba_cache(
+                 [
+                     (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                     (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                     (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
+                 ]
+             ),
+         ]
+
+         with torch_export_patches(patch_transformers=True):
+             for cache in caches:
+                 print(f"-- {cache.__class__.__name__}")
+                 pprint.pprint(all_dynamic_shapes_from_inputs(cache))
+     """
+     if isinstance(dim_prefix, str):
+         prefixes: Set[str] = set()
+
+         def tensor_to_shape(tensor):
+             n = len(prefixes)
+             p = f"{dim_prefix}_{n}"
+             prefixes.add(p)
+             return {i: f"{p}_{i}" for i in range(tensor.ndim)}
+
+     else:
+
+         def tensor_to_shape(tensor):
+             return {i: dim_prefix for i in range(tensor.ndim)}  # noqa: C420
+
+     return flatten_unflatten_for_dynamic_shapes(
+         inputs, change_function=tensor_to_shape, use_dict=True
+     )
+
+
+ def guess_dynamic_shapes_from_inputs(
+     inputs: List[Any], auto: Union[bool, str] = False
+ ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
+     """
+     Guesses which dimensions are dynamic from a set of inputs.
+     Every dimension taking different values across the input sets
+     becomes dynamic. Every dimension that does not change remains static.
+
+     :param inputs: a list of input sets
+     :param auto: True for ``torch.export.Dim.AUTO``,
+         False for ``torch.export.Dim.DYNAMIC``,
+         a string to get a unique string for every dynamic dimension
+     :return: args and kwargs
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
+
+         bsize, nheads, slen, dim = 2, 1, 30, 96
+         inputs1 = dict(
+             input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
+             attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+             position_ids=torch.arange(3, dtype=torch.int64),
+             past_key_values=make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(bsize, nheads, slen, dim),
+                         torch.randn(bsize, nheads, slen, dim),
+                     ),
+                 ]
+             ),
+         )
+         bsize, nheads, slen, dim = 3, 1, 33, 96
+         inputs2 = dict(
+             input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
+             attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
+             position_ids=torch.arange(4, dtype=torch.int64),
+             past_key_values=make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(bsize, nheads, slen, dim),
+                         torch.randn(bsize, nheads, slen, dim),
+                     ),
+                 ]
+             ),
+         )
+         ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
+         pprint.pprint(ds)
+
+     This function returns something equivalent to
+     :class:`torch.export.dynamic_shapes.AdditionalInputs`, but the
+     latter requires a model.
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
+         from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+         data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
+         ds = torch.export.dynamic_shapes.AdditionalInputs()
+         ds.add((), data["inputs"])
+         ds.add((), data["inputs2"])
+         pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
+     """
+     mi = ModelInputs(None, inputs)
+     return mi.guess_dynamic_shapes(auto=auto)
+
+
+ def make_fake_with_dynamic_dimensions(
+     x: Any, dynamic_shapes: Any, context: Optional["FakeTensorContext"] = None  # noqa: F821
+ ) -> Tuple[Any, "FakeTensorContext"]:  # noqa: F821
+     """
+     Replaces all tensors with fake tensors respecting the same
+     constraints as the given dynamic shapes.
+     This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
+     Parameter ``context`` is used to reuse the same object when a dynamic
+     dimension is given the same name as another one.
+
+     A simple tensor:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+
+         inputs, _ = make_fake_with_dynamic_dimensions(
+             torch.rand((2, 3, 4, 5), dtype=torch.float32),
+             {0: "batch", 2: "cache_length"},
+         )
+         print(inputs)
+
+     Two tensors:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+
+         inputs, _ = make_fake_with_dynamic_dimensions(
+             (
+                 torch.rand((2, 3, 4, 5), dtype=torch.float32),
+                 torch.rand((2, 3, 4, 5), dtype=torch.float32),
+             ),
+             ({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
+         )
+         print(inputs)
+
+     With a cache:
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+
+         inputs, _ = make_fake_with_dynamic_dimensions(
+             dict(
+                 input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                 attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                 position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                 past_key_values=make_dynamic_cache(
+                     [
+                         (
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                         ),
+                         (
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                         ),
+                     ]
+                 ),
+             ),
+             dynamic_shapes={
+                 "input_ids": {0: "batch", 1: "seq_length"},
+                 "attention_mask": {0: "batch", 1: "cache+seq"},
+                 "position_ids": {0: "batch", 1: "seq_length"},
+                 "past_key_values": [
+                     {0: "batch", 2: "cache_length"},
+                     {0: "batch", 2: "cache_length"},
+                     {0: "batch", 2: "cache_length"},
+                     {0: "batch", 2: "cache_length"},
+                 ],
+             },
+         )
+         pprint.pprint(inputs)
+     """
+     if x is None:
+         return None, None
+     if context is None:
+         from ..helpers.fake_tensor_helper import FakeTensorContext
+
+         context = FakeTensorContext()
+
+     return context.make_fake_with_dynamic_dimensions(x, dynamic_shapes), context
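The docstrings above show how to build a dynamic-shapes structure but not how it is consumed. Below is a minimal sketch under stated assumptions: the toy ``TinyModel`` and the tensor sizes are illustrative and not part of the package, and plain tensor inputs need no patches, unlike the cache examples above:

    import torch
    from onnx_diagnostic.export.shape_helper import (
        all_dynamic_shapes_from_inputs,
        guess_dynamic_shapes_from_inputs,
    )


    class TinyModel(torch.nn.Module):
        # hypothetical toy model: the two inputs are independent, so no
        # equality constraint ties their dimensions together
        def forward(self, x, y):
            return x.mean() + y.mean()


    kwargs = dict(x=torch.randn(2, 3), y=torch.randn(4, 5))

    # every dimension of every tensor becomes a named dynamic dimension
    ds = all_dynamic_shapes_from_inputs(kwargs)
    ep = torch.export.export(TinyModel(), (), kwargs=kwargs, dynamic_shapes=ds)
    print(ep)

    # or guess the dynamic dimensions from two input sets: only the
    # dimensions whose values differ between the sets are reported as dynamic
    kwargs2 = dict(x=torch.randn(2, 7), y=torch.randn(6, 5))
    _, kwargs_ds = guess_dynamic_shapes_from_inputs([kwargs, kwargs2], auto="dim")
    print(kwargs_ds)
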
onnx_diagnostic/export/validate.py
@@ -0,0 +1,173 @@
+ import inspect
+ import itertools
+ import time
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ import torch
+ from ..helpers import string_type, max_diff, string_diff
+ from ..helpers.torch_helper import torch_deepcopy
+ from .dynamic_shapes import CoupleInputsDynamicShapes
+
+
+ def compare_modules(
+     modep: torch.nn.Module,
+     mod: Optional[torch.nn.Module] = None,
+     args: Optional[Tuple[Any, ...]] = None,
+     kwargs: Optional[Dict[str, Any]] = None,
+     copy: bool = False,
+     exc: bool = True,
+     verbose: int = 0,
+     atol: float = 1e-2,
+     rtol: float = 1e-1,
+ ) -> Dict[str, Any]:
+     """
+     Compares two torch modules, usually one coming from an exported program,
+     the other being the original model.
+
+     :param modep: first module
+     :param mod: second module (it produces the expected values)
+     :param args: positional arguments
+     :param kwargs: named arguments
+     :param copy: copy the inputs before executing the model (they may modify them inplace)
+     :param exc: raise exception if discrepancies are too high
+     :param verbose: verbosity level
+     :param atol: absolute tolerance
+     :param rtol: relative tolerance
+     :return: dictionary with inputs, outputs and measured discrepancies
+
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import torch
+         from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes
+
+         class Model(torch.nn.Module):
+             def forward(self, x, y):
+                 return x + y
+
+         model = Model()
+         x = torch.randn((5, 6))
+         y = torch.randn((1, 6))
+         model(x, y)  # make sure it runs
+
+         ds = ({0: "a", 1: "b"}, {1: "b"})
+         cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
+         ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
+         validate_ep(
+             ep,
+             model,
+             args=(x, y),
+             verbose=2,
+             copy=True,
+             dynamic_shapes=ds,
+             values_to_try={"a": [5, 10], "b": [10, 20]},
+         )
+
+     """
+     args = args or ()
+     kwargs = kwargs or {}
+
+     def _get(a):
+         return torch_deepcopy(a) if copy else a
+
+     if verbose:
+         begin = time.perf_counter()
+         print(
+             f"[compare_modules] check ep with "
+             f"args={string_type(args, with_shape=True, with_device=True)}, "
+             f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..."
+         )
+     got = modep(*_get(args), **_get(kwargs))
+     if verbose:
+         d = time.perf_counter() - begin
+         print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
+     if mod:
+         if verbose:
+             begin = time.perf_counter()
+             print("[compare_modules] run torch module...")
+         expected = mod(*_get(args), **_get(kwargs))
+         diff = max_diff(expected, got)
+         if verbose:
+             d = time.perf_counter() - begin
+             print(
+                 f"[compare_modules] done in {d} with "
+                 f"output={string_type(expected, with_shape=True)}"
+             )
+             print(f"[compare_modules] discrepancies={string_diff(diff)}")
+         assert not exc or (
+             isinstance(diff["abs"], float)
+             and isinstance(diff["rel"], float)
+             and diff["abs"] <= atol
+             and diff["rel"] <= rtol
+         ), f"Discrepancies={string_diff(diff)} higher than expected."
+         return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
+     return dict(args=args, kwargs=kwargs, got=got)
+
+
+ def validate_ep(
+     ep: Union[torch.nn.Module, torch.export.ExportedProgram],
+     mod: Optional[torch.nn.Module] = None,
+     args: Optional[Tuple[Any, ...]] = None,
+     kwargs: Optional[Dict[str, Any]] = None,
+     copy: bool = False,
+     dynamic_shapes: Optional[Any] = None,
+     values_to_try: Optional[Dict[str, List[int]]] = None,
+     exc: bool = True,
+     verbose: int = 0,
+     atol: float = 1e-2,
+     rtol: float = 1e-1,
+ ) -> List[Dict[str, Any]]:
+     """
+     Validates an exported program.
+
+     :param ep: exported program (or module) to validate
+     :param mod: second module (it produces the expected values)
+     :param args: positional arguments
+     :param kwargs: named arguments
+     :param copy: copy the inputs before executing the model (they may modify them inplace)
+     :param dynamic_shapes: dynamic shapes, string should be used not ``torch.export.Dim``
+     :param values_to_try: dictionary with the values to try for every dynamic dimension
+     :param exc: raise exception if discrepancies are too high
+     :param verbose: verbosity level
+     :param atol: absolute tolerance
+     :param rtol: relative tolerance
+     :return: list of dictionaries with inputs, outputs and measured discrepancies
+ """
137
+ modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep
138
+
139
+ results = [
140
+ compare_modules(
141
+ modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol
142
+ )
143
+ ]
144
+
145
+ assert (dynamic_shapes and values_to_try) or (
146
+ not dynamic_shapes and not values_to_try
147
+ ), "Either both dynamic_shapes and values_to_try are specified, either none."
148
+ if not dynamic_shapes or not values_to_try:
149
+ return results
150
+
151
+ items = list(values_to_try.items())
152
+ keys = [_[0] for _ in items]
153
+ values = [_[1] for _ in items]
154
+ all_vals = list(itertools.product(*values))
155
+ cpl = CoupleInputsDynamicShapes(
156
+ args or (),
157
+ kwargs or {},
158
+ dynamic_shapes,
159
+ args_names=(
160
+ list(inspect.signature(modep.forward).parameters) if args and kwargs else None
161
+ ),
162
+ )
163
+ for i, vals in enumerate(all_vals):
164
+ change_dims = dict(zip(keys, vals))
165
+ if verbose:
166
+ print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}")
167
+ new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
168
+ na, nkw = new_params
169
+ c = compare_modules(
170
+ modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol
171
+ )
172
+ results.append(c)
173
+ return results
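As a complement to the example embedded in the ``compare_modules`` docstring, the sketch below shows one way ``validate_ep`` might be used end to end; the toy model, the dimension names and the value grid are illustrative assumptions, not part of the package:

    import torch
    from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes


    class TinyModel(torch.nn.Module):
        # hypothetical toy model used only for this sketch
        def forward(self, x, y):
            return x * 2 + y


    model = TinyModel()
    x, y = torch.randn(4, 8), torch.randn(1, 8)
    ds = ({0: "batch", 1: "length"}, {1: "length"})

    # export with string-based dynamic shapes, as in the docstring example
    cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
    ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())

    # re-run the exported module against the original one for every
    # combination of the dimension values and collect the discrepancies
    results = validate_ep(
        ep,
        model,
        args=(x, y),
        dynamic_shapes=ds,
        values_to_try={"batch": [2, 8], "length": [8, 16]},
    )
    for r in results:
        print(r["diff"])  # absolute and relative differences measured by max_diff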