onnx-diagnostic 0.8.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/export/shape_helper.py
@@ -0,0 +1,296 @@
from typing import Any, Dict, List, Set, Optional, Tuple, Union
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from .dynamic_shapes import ModelInputs


def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Returns the dynamic shapes for the given inputs.
    All dimensions are considered as dynamic.
    ``dim_prefix`` can be a string (the function uses it as a prefix),
    or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
    Depending on the version of transformers, the serialization functions
    of the DynamicCache class are registered automatically or not (>= 4.51, < 4.55).

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_export_patches import torch_export_patches

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim),
                  torch.randn(bsize, nheads, slen, dim))]
            ),
        )
        with torch_export_patches(patch_transformers=True):
            ds = all_dynamic_shapes_from_inputs(inputs)
        pprint.pprint(ds)

    For this function to work, patches must be enabled if :epkg:`transformers`
    does not implement the serialization functions.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import (
            make_dynamic_cache,
            make_encoder_decoder_cache,
            make_mamba_cache,
            make_sliding_window_cache,
            make_static_cache,
        )
        from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_export_patches import torch_export_patches

        caches = [
            make_dynamic_cache(
                [
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                ]
            ),
            make_encoder_decoder_cache(
                make_dynamic_cache(
                    [
                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    ]
                ),
                make_dynamic_cache(
                    [
                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                    ]
                ),
            ),
            make_sliding_window_cache(
                [
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                ]
            ),
            make_static_cache(
                [
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                ],
                max_cache_len=15,
            ),
            make_mamba_cache(
                [
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                ]
            ),
        ]

        with torch_export_patches(patch_transformers=True):
            for cache in caches:
                print(f"-- {cache.__class__.__name__}")
                pprint.pprint(all_dynamic_shapes_from_inputs(cache))
    """
    if isinstance(dim_prefix, str):
        prefixes: Set[str] = set()

        def tensor_to_shape(tensor):
            n = len(prefixes)
            p = f"{dim_prefix}_{n}"
            prefixes.add(p)
            return {i: f"{p}_{i}" for i in range(tensor.ndim)}

    else:

        def tensor_to_shape(tensor):
            return {i: dim_prefix for i in range(tensor.ndim)}  # noqa: C420

    return flatten_unflatten_for_dynamic_shapes(
        inputs, change_function=tensor_to_shape, use_dict=True
    )


def guess_dynamic_shapes_from_inputs(
    inputs: List[Any], auto: Union[bool, str] = False
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Guesses which dimensions are dynamic from a set of inputs.
    Every dimension taking different values across the input sets is
    considered dynamic. Every dimension not changing remains static.

    :param inputs: a list of input sets
    :param auto: True for ``torch.export.Dim.AUTO``,
        False for ``torch.export.Dim.DYNAMIC``,
        a string to get a unique string for every dynamic dimension
    :return: args and kwargs

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs1 = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )
        bsize, nheads, slen, dim = 3, 1, 33, 96
        inputs2 = dict(
            input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
            position_ids=torch.arange(4, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )
        ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
        pprint.pprint(ds)

    This function returns something equivalent to
    :class:`torch.export.dynamic_shapes.AdditionalInputs`,
    but that class requires a model.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
        ds = torch.export.dynamic_shapes.AdditionalInputs()
        ds.add((), data["inputs"])
        ds.add((), data["inputs2"])
        pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
    """
    mi = ModelInputs(None, inputs)
    return mi.guess_dynamic_shapes(auto=auto)


def make_fake_with_dynamic_dimensions(
    x: Any, dynamic_shapes: Any, context: Optional["FakeTensorContext"] = None  # noqa: F821
) -> Tuple[Any, "FakeTensorContext"]:  # noqa: F821
    """
    Replaces all tensors with fake tensors respecting the same
    constraints as the given dynamic shapes.
    This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
    Parameter ``context`` is used to reuse the same object when a dynamic
    dimension is given the same name as another one.

    A simple tensor:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

        inputs, _ = make_fake_with_dynamic_dimensions(
            torch.rand((2, 3, 4, 5), dtype=torch.float32),
            {0: "batch", 2: "cache_length"},
        )
        print(inputs)

    Two tensors:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

        inputs, _ = make_fake_with_dynamic_dimensions(
            (
                torch.rand((2, 3, 4, 5), dtype=torch.float32),
                torch.rand((2, 3, 4, 5), dtype=torch.float32),
            ),
            ({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
        )
        print(inputs)

    With a cache:

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

        inputs, _ = make_fake_with_dynamic_dimensions(
            dict(
                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
                past_key_values=make_dynamic_cache(
                    [
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                    ]
                ),
            ),
            dynamic_shapes={
                "input_ids": {0: "batch", 1: "seq_length"},
                "attention_mask": {0: "batch", 1: "cache+seq"},
                "position_ids": {0: "batch", 1: "seq_length"},
                "past_key_values": [
                    {0: "batch", 2: "cache_length"},
                    {0: "batch", 2: "cache_length"},
                    {0: "batch", 2: "cache_length"},
                    {0: "batch", 2: "cache_length"},
                ],
            },
        )
        pprint.pprint(inputs)
    """
    if x is None:
        return None, None
    if context is None:
        from ..helpers.fake_tensor_helper import FakeTensorContext

        context = FakeTensorContext()

    return context.make_fake_with_dynamic_dimensions(x, dynamic_shapes), context
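
Note on the ``auto`` argument of guess_dynamic_shapes_from_inputs: it accepts a string (unique
names per dynamic dimension), True (torch.export.Dim.AUTO) or False (torch.export.Dim.DYNAMIC).
A minimal sketch of the three modes on plain tensors, not part of the package, assuming the
returned pair is ``(args_shapes, kwargs_shapes)`` as the docstring states:

    import torch
    from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

    # two sets of keyword inputs: the first dimension varies (2 -> 3), the second stays at 8
    inputs1 = dict(x=torch.randn(2, 8), mask=torch.ones(2, 8, dtype=torch.int64))
    inputs2 = dict(x=torch.randn(3, 8), mask=torch.ones(3, 8, dtype=torch.int64))

    for auto in ("d", True, False):
        # auto="d": unique string names, True: Dim.AUTO, False: Dim.DYNAMIC
        args_shapes, kwargs_shapes = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto=auto)
        print(auto, kwargs_shapes)
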
onnx_diagnostic/export/validate.py
@@ -0,0 +1,173 @@
import inspect
import itertools
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from ..helpers import string_type, max_diff, string_diff
from ..helpers.torch_helper import torch_deepcopy
from .dynamic_shapes import CoupleInputsDynamicShapes


def compare_modules(
    modep: torch.nn.Module,
    mod: Optional[torch.nn.Module] = None,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    exc: bool = True,
    verbose: int = 0,
    atol: float = 1e-2,
    rtol: float = 1e-1,
) -> Dict[str, Any]:
    """
    Compares two torch modules, usually one coming from an exported program,
    the other being the original model.

    :param modep: first module
    :param mod: second module (it produces the expected values)
    :param args: positional arguments
    :param kwargs: named arguments
    :param copy: copy the inputs before executing the model (the execution may modify them in place)
    :param exc: raise an exception if discrepancies are too high
    :param verbose: verbosity level
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :return: dictionary with inputs, outputs and tolerance

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x, y)  # to make sure it runs

        ds = ({0: "a", 1: "b"}, {1: "b"})
        cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
        ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
        validate_ep(
            ep,
            model,
            args=(x, y),
            verbose=2,
            copy=True,
            dynamic_shapes=ds,
            values_to_try={"a": [5, 10], "b": [10, 20]},
        )

    """
    args = args or ()
    kwargs = kwargs or {}

    def _get(a):
        return torch_deepcopy(a) if copy else a

    if verbose:
        begin = time.perf_counter()
        print(
            f"[compare_modules] check ep with "
            f"args={string_type(args, with_shape=True, with_device=True)}, "
            f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..."
        )
    got = modep(*_get(args), **_get(kwargs))
    if verbose:
        d = time.perf_counter() - begin
        print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
    if mod:
        if verbose:
            begin = time.perf_counter()
            print("[compare_modules] run torch module...")
        expected = mod(*_get(args), **_get(kwargs))
        diff = max_diff(expected, got)
        if verbose:
            d = time.perf_counter() - begin
            print(
                f"[compare_modules] done in {d} with "
                f"output={string_type(expected, with_shape=True)}"
            )
            print(f"[compare_modules] discrepancies={string_diff(diff)}")
        assert not exc or (
            isinstance(diff["abs"], float)
            and isinstance(diff["rel"], float)
            and diff["abs"] <= atol
            and diff["rel"] <= rtol
        ), f"Discrepancies={string_diff(diff)} higher than expected."
        return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
    return dict(args=args, kwargs=kwargs, got=got)


def validate_ep(
    ep: Union[torch.nn.Module, torch.export.ExportedProgram],
    mod: Optional[torch.nn.Module] = None,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    dynamic_shapes: Optional[Any] = None,
    values_to_try: Optional[Dict[str, List[int]]] = None,
    exc: bool = True,
    verbose: int = 0,
    atol: float = 1e-2,
    rtol: float = 1e-1,
) -> List[Dict[str, Any]]:
    """
    Validates an exported program.

    :param ep: exported program or module to validate
    :param mod: second module (it produces the expected values)
    :param args: positional arguments
    :param kwargs: named arguments
    :param copy: copy the inputs before executing the model (the execution may modify them in place)
    :param dynamic_shapes: dynamic shapes, strings should be used, not ``torch.export.Dim``
    :param values_to_try: dictionary with the values to try for every dynamic dimension
    :param exc: raise an exception if discrepancies are too high
    :param verbose: verbosity level
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :return: list of dictionaries with inputs, outputs and tolerance
    """
    modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep

    results = [
        compare_modules(
            modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol
        )
    ]

    assert (dynamic_shapes and values_to_try) or (
        not dynamic_shapes and not values_to_try
), "Either both dynamic_shapes and values_to_try are specified, either none."
    if not dynamic_shapes or not values_to_try:
        return results

    items = list(values_to_try.items())
    keys = [_[0] for _ in items]
    values = [_[1] for _ in items]
    all_vals = list(itertools.product(*values))
    cpl = CoupleInputsDynamicShapes(
        args or (),
        kwargs or {},
        dynamic_shapes,
        args_names=(
            list(inspect.signature(modep.forward).parameters) if args and kwargs else None
        ),
    )
    for i, vals in enumerate(all_vals):
        change_dims = dict(zip(keys, vals))
        if verbose:
            print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}")
        new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
        na, nkw = new_params
        c = compare_modules(
            modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol
        )
        results.append(c)
    return results
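
Note on compare_modules: it can also be called directly, without going through validate_ep.
A minimal sketch based on the signature above, illustrative only; the toy model and the
tolerances are made up for the example:

    import torch
    from onnx_diagnostic.export.validate import compare_modules

    class Model(torch.nn.Module):
        def forward(self, x, y):
            return x * y + 1

    model = Model()
    x, y = torch.randn(4, 6), torch.randn(4, 6)
    ep = torch.export.export(model, (x, y))

    # compare the module rebuilt from the exported program against the original model
    report = compare_modules(ep.module(), model, args=(x, y), verbose=1, atol=1e-5, rtol=1e-4)
    print(report["diff"])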