onnx-diagnostic 0.7.0-py3-none-any.whl → 0.7.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +213 -5
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +42 -20
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +1561 -177
- onnx_diagnostic/helpers/torch_helper.py +6 -2
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/image_text_to_text.py +69 -18
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +73 -5
- onnx_diagnostic/torch_models/hghub/hub_data.py +7 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +74 -14
- onnx_diagnostic/torch_models/validate.py +45 -16
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +29 -24
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0
@@ -735,7 +735,8 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
                     [t.to(to_value) for t in value.key_cache],
                     [t.to(to_value) for t in value.value_cache],
                 )
-            )
+            ),
+            max_cache_len=value.max_cache_len,
         )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(
@@ -784,7 +785,10 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
         )
     if value.__class__.__name__ == "StaticCache":
-        return make_static_cache(
+        return make_static_cache(
+            torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
+            max_cache_len=value.max_cache_len,
+        )
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
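
For orientation, a minimal sketch of how the updated helper is called after this change (the import location is inferred from the file list above; shapes are arbitrary):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache  # assumed location

    # Two layers of (key, value) tensors shaped (batch, num_heads, cache_len, head_dim).
    pairs = [(torch.randn(2, 8, 16, 32), torch.randn(2, 8, 16, 32)) for _ in range(2)]

    # 0.7.2 forwards the cache capacity explicitly when rebuilding a StaticCache.
    cache = make_static_cache(pairs, max_cache_len=16)
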
@@ -11,6 +11,7 @@ from . import (
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 )
@@ -27,6 +28,7 @@ __TASKS__ = [
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 ]
@@ -96,10 +96,10 @@ def get_inputs(
                 for i in range(num_hidden_layers)
             ]
         ),
-
+        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
             torch.int64
         ),
-
+        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
             torch.int64
         ),
     )
@@ -132,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-
-
-
-
-
-
-
-
-
-
+        if hasattr(config, "text_config"):
+            check_hasattr(
+                config.text_config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+            )
+            check_hasattr(config, "vision_config")
+            text_config = True
+        else:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "vision_config",
+            )
+            text_config = False
         check_hasattr(config.vision_config, "image_size", "num_channels")
     kwargs = dict(
         batch_size=2,
@@ -150,17 +164,54 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         head_dim=(
             16
             if config is None
-            else getattr(
+            else getattr(
+                config,
+                "head_dim",
+                (config.text_config.hidden_size if text_config else config.hidden_size)
+                // (
+                    config.text_config.num_attention_heads
+                    if text_config
+                    else config.num_attention_heads
+                ),
+            )
+        ),
+        dummy_max_token_id=(
+            31999
+            if config is None
+            else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
+        ),
+        num_hidden_layers=(
+            4
+            if config is None
+            else (
+                config.text_config.num_hidden_layers
+                if text_config
+                else config.num_hidden_layers
+            )
         ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=4 if config is None else config.num_hidden_layers,
         num_key_value_heads=(
             8
             if config is None
-            else
+            else (
+                _pick(config.text_config, "num_key_value_heads", "num_attention_heads")
+                if text_config
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            )
+        ),
+        intermediate_size=(
+            1024
+            if config is None
+            else (
+                config.text_config.intermediate_size
+                if text_config
+                else config.intermediate_size
+            )
+        ),
+        hidden_size=(
+            512
+            if config is None
+            else (config.text_config.hidden_size if text_config else config.hidden_size)
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
-        hidden_size=512 if config is None else config.hidden_size,
         width=224 if config is None else config.vision_config.image_size,
         height=224 if config is None else config.vision_config.image_size,
         num_channels=3 if config is None else config.vision_config.num_channels,
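
The branches above boil down to one fallback rule: multimodal configurations keep their text parameters under config.text_config, and head_dim is read from the configuration when present, otherwise derived as hidden_size // num_attention_heads. A sketch of that rule (guess_head_dim is a hypothetical helper, not part of the package):

    def guess_head_dim(config, text_config: bool) -> int:  # hypothetical helper
        # Pick the sub-config that actually carries the text parameters.
        sub = config.text_config if text_config else config
        return getattr(config, "head_dim", sub.hidden_size // sub.num_attention_heads)
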
@@ -109,7 +109,7 @@ def get_inputs(
     sequence_length2 = seq_length_multiple
 
     shapes = {
-        "input_ids": {0: batch, 1:
+        "input_ids": {0: batch, 1: "sequence_length"},
         "attention_mask": {
             0: batch,
             1: "cache+seq",  # cache_length + seq_length
@@ -176,8 +176,10 @@ def get_inputs(
             "attention_mask": {0: batch, 2: "seq"},
             "cache_position": {0: "seq"},
             "past_key_values": [
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch} for _ in range(num_hidden_layers)],
+                [{0: batch} for _ in range(num_hidden_layers)],
             ],
         }
         inputs = dict(
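
The switch from {0: batch, 2: cache_length} to {0: batch} presumably reflects that a StaticCache has a fixed capacity (max_cache_len), so only the batch axis remains dynamic. The resulting spec looks like this (illustrative values only):

    num_hidden_layers = 4  # placeholder
    past_key_values_shapes = [
        [{0: "batch"} for _ in range(num_hidden_layers)],  # one spec per layer of keys
        [{0: "batch"} for _ in range(num_hidden_layers)],  # one spec per layer of values
    ]
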
@@ -188,18 +190,25 @@ def get_inputs(
                 (batch_size, num_key_value_heads, sequence_length2, head_dim)
             ).to(torch.bool),
             cache_position=torch.arange(sequence_length2).to(torch.int64),
-            past_key_values=
+            past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
-                            batch_size,
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                         torch.randn(
-                            batch_size,
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
-                ]
+                ],
+                max_cache_len=max(sequence_length + sequence_length2, head_dim),
             ),
         )
     else:
@@ -230,7 +239,7 @@ def get_inputs(
             position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
             .to(torch.int64)
             .expand((batch_size, -1)),
-            past_key_values=make_cache(
+            past_key_values=make_cache(  # type: ignore[operator]
                 [
                     (
                         torch.randn(
@@ -0,0 +1,91 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr, pick
+
+__TASK__ = "text-to-image"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "sample_size", "cross_attention_dim")
+    kwargs = dict(
+        sample_size=min(config["sample_size"], 32),
+        cross_attention_dim=min(config["cross_attention_dim"], 64),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    cache_length: int,
+    in_channels: int,
+    sample_size: int,
+    cross_attention_dim: int,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-to-image``.
+    Example:
+
+    ::
+
+        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
+        timestep:T7s=101
+        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    shapes = {
+        "sample": {0: batch},
+        "timestep": {},
+        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
+    }
+    inputs = dict(
+        sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to(
+            torch.float32
+        ),
+        timestep=torch.tensor([101], dtype=torch.int64),
+        encoder_hidden_states=torch.randn(
+            (batch_size, sequence_length, cross_attention_dim)
+        ).to(torch.float32),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length,
+            cache_length=cache_length + 1,
+            in_channels=in_channels,
+            sample_size=sample_size,
+            cross_attention_dim=cross_attention_dim,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=pick(config, "in_channels", 4),
+        cache_length=77,
+        in_channels=pick(config, "in_channels", 4),
+        sample_size=pick(config, "sample_size", 32),
+        cross_attention_dim=pick(config, "cross_attention_dim", 64),
+    )
+    return kwargs, get_inputs
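
As a reading aid, the dummy inputs the new task produces look roughly like this; shapes are taken from the docstring example above, and the snippet mirrors get_inputs rather than calling it:

    import torch

    inputs = dict(
        sample=torch.randn(2, 4, 96, 96),                # (batch, in_channels, sample_size, sample_size)
        timestep=torch.tensor([101], dtype=torch.int64),
        encoder_hidden_states=torch.randn(2, 77, 1024),  # (batch, encoder_length, cross_attention_dim)
    )
    dynamic_shapes = {
        "sample": {0: "batch"},
        "timestep": {},
        "encoder_hidden_states": {0: "batch", 1: "encoder_length"},
    }
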
@@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
 
 @contextlib.contextmanager
 def register_additional_serialization_functions(
-    patch_transformers: bool = False, verbose: int = 0
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
 ) -> Callable:
     """The necessary modifications to run the fx Graph."""
-    fct_callable =
-
+    fct_callable = (
+        replacement_before_exporting
+        if patch_transformers or patch_diffusers
+        else (lambda x: x)
+    )
+    done = register_cache_serialization(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
     try:
         yield fct_callable
     finally:
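
A minimal usage sketch of the extended context manager; only the patch_diffusers argument comes from this diff, and the import path is an assumption:

    from onnx_diagnostic.torch_export_patches import (  # assumed import path
        register_additional_serialization_functions,
    )

    # Registers cache serialization for both transformers and diffusers classes
    # without applying any other export patch.
    with register_additional_serialization_functions(
        patch_transformers=True, patch_diffusers=True, verbose=1
    ) as modificator:
        ...  # run torch.export.export here
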
@@ -150,6 +156,7 @@ def torch_export_patches(
     patch_sympy: bool = True,
     patch_torch: bool = True,
     patch_transformers: bool = False,
+    patch_diffusers: bool = False,
     catch_constraints: bool = True,
     stop_if_static: int = 0,
     verbose: int = 0,
@@ -165,6 +172,7 @@ def torch_export_patches(
     :param patch_sympy: fix missing method ``name`` for IntegerConstant
     :param patch_torch: patches :epkg:`torch` with supported implementation
     :param patch_transformers: patches :epkg:`transformers` with supported implementation
+    :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
     :param catch_constraints: catch constraints related to dynamic shapes,
         as a result, some dynamic dimension may turn into static ones,
         the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
@@ -174,8 +182,8 @@ def torch_export_patches(
         and show a stack trace indicating the exact location of the issue,
         ``if stop_if_static > 1``, more methods are replace to catch more
         issues
-    :param patch: if False, disable all patches
-        serialization
+    :param patch: if False, disable all patches but keeps the registration of
+        serialization functions if other patch functions are enabled
     :param custom_patches: to apply custom patches,
         every patched class must define static attributes
         ``_PATCHES_``, ``_PATCHED_CLASS_``
@@ -249,6 +257,7 @@ def torch_export_patches(
             patch_sympy=patch_sympy,
             patch_torch=patch_torch,
             patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
             catch_constraints=catch_constraints,
             stop_if_static=stop_if_static,
             verbose=verbose,
@@ -261,7 +270,11 @@ def torch_export_patches(
         pass
     elif not patch:
         fct_callable = lambda x: x  # noqa: E731
-        done = register_cache_serialization(
+        done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )
         try:
             yield fct_callable
         finally:
@@ -281,7 +294,11 @@ def torch_export_patches(
         # caches
         ########
 
-        cache_done = register_cache_serialization(
+        cache_done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )
 
         #############
         # patch sympy
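
Taken together, the new flag is meant to be used like this (a sketch only: model, inputs and dynamic_shapes are placeholders, and the import path is an assumption):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path

    with torch_export_patches(patch_transformers=True, patch_diffusers=True, verbose=1):
        ep = torch.export.export(
            model,                          # placeholder: the module under test
            (),
            kwargs=inputs,                  # placeholder: e.g. from a task's get_inputs
            dynamic_shapes=dynamic_shapes,  # placeholder: matching shape spec
        )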