onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +22 -5
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +23 -12
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +308 -83
- onnx_diagnostic/helpers/rt_helper.py +11 -1
- onnx_diagnostic/helpers/torch_helper.py +7 -3
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
- onnx_diagnostic/torch_models/validate.py +36 -12
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/text_to_image.py (new file, +91 lines):

@@ -0,0 +1,91 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr, pick
+
+__TASK__ = "text-to-image"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "sample_size", "cross_attention_dim")
+    kwargs = dict(
+        sample_size=min(config["sample_size"], 32),
+        cross_attention_dim=min(config["cross_attention_dim"], 64),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    cache_length: int,
+    in_channels: int,
+    sample_size: int,
+    cross_attention_dim: int,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-to-image``.
+    Example:
+
+    ::
+
+        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
+        timestep:T7s=101
+        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    shapes = {
+        "sample": {0: batch},
+        "timestep": {},
+        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
+    }
+    inputs = dict(
+        sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to(
+            torch.float32
+        ),
+        timestep=torch.tensor([101], dtype=torch.int64),
+        encoder_hidden_states=torch.randn(
+            (batch_size, sequence_length, cross_attention_dim)
+        ).to(torch.float32),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length,
+            cache_length=cache_length + 1,
+            in_channels=in_channels,
+            sample_size=sample_size,
+            cross_attention_dim=cross_attention_dim,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=pick(config, "in_channels", 4),
+        cache_length=77,
+        in_channels=pick(config, "in_channels", 4),
+        sample_size=pick(config, "sample_size", 32),
+        cross_attention_dim=pick(config, "cross_attention_dim", 64),
+    )
+    return kwargs, get_inputs
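The new task module gives onnx-diagnostic a dummy-input generator for text-to-image models. A minimal usage sketch, based only on the code above: the `None` placeholders rely on `get_inputs` never reading `model`, and on the docstring's statement that a `None` config falls back to typical dimensions, both observations from this diff rather than documented API.

```python
# Hypothetical usage sketch of the new task module; not taken from the package docs.
from onnx_diagnostic.tasks import text_to_image

kwargs, fct = text_to_image.random_input_kwargs(config=None)  # typical dimensions
data = fct(model=None, config=None, add_second_input=True, **kwargs)

print(data["inputs"]["sample"].shape)    # torch.Size([2, 4, 32, 32]) with the defaults
print(data["dynamic_shapes"])            # {'sample': {0: 'batch'}, 'timestep': {}, ...}
print(data["inputs2"]["sample"].shape)   # batch_size + 1 -> torch.Size([3, 4, 32, 32])
```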
From onnx_diagnostic/torch_export_patches/eval/__init__.py (the only file above with a matching +3 −1 change):

@@ -337,7 +337,7 @@ def _make_exporter_onnx(
     from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
 
     opts = {}
-    opts["strict"] = "-
+    opts["strict"] = "-strict" in exporter
     opts["fallback"] = "-fallback" in exporter
     opts["tracing"] = "-tracing" in exporter
     opts["jit"] = "-jit" in exporter

@@ -520,6 +520,8 @@ def run_exporter(
         return res
 
     onx, builder = res
+    base["onx"] = onx
+    base["builder"] = builder
     if verbose >= 9:
         print("[run_exporter] onnx model")
         print(
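The first hunk fixes the assignment of `opts["strict"]` so that, like the other options, it is derived from a suffix of the exporter name; the second stores the produced ONNX model and its builder in the result dictionary. A small self-contained illustration of the suffix convention (the exporter string below is hypothetical):

```python
# Illustration only: how exporter-name suffixes map to export options after the fix.
exporter = "custom-strict-fallback"  # hypothetical exporter string
opts = {
    "strict": "-strict" in exporter,      # True
    "fallback": "-fallback" in exporter,  # True
    "tracing": "-tracing" in exporter,    # False
    "jit": "-jit" in exporter,            # False
}
print(opts)
```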
From onnx_diagnostic/torch_export_patches/onnx_export_errors.py (matching the +24 −7 entry above):

@@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
 
 @contextlib.contextmanager
 def register_additional_serialization_functions(
-    patch_transformers: bool = False, verbose: int = 0
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
 ) -> Callable:
     """The necessary modifications to run the fx Graph."""
-    fct_callable =
-
+    fct_callable = (
+        replacement_before_exporting
+        if patch_transformers or patch_diffusers
+        else (lambda x: x)
+    )
+    done = register_cache_serialization(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
     try:
         yield fct_callable
     finally:
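The context manager now accepts a `patch_diffusers` flag and forwards both flags to `register_cache_serialization`. A sketch of how it might wrap an export, under stated assumptions: the import path follows the file shown in this diff, `Tiny` and its inputs are placeholders, and passing a plain dict of tensors through `modify_inputs` unchanged is an assumption rather than documented behavior.

```python
import torch
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    register_additional_serialization_functions,
)

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x * 2

model, inputs = Tiny(), {"x": torch.randn(2, 4)}

# Register cache (de)serialization for transformers and, new in this release,
# diffusers classes, without applying the other export patches.
with register_additional_serialization_functions(
    patch_transformers=True,
    patch_diffusers=True,  # flag added by this diff
    verbose=1,
) as modify_inputs:
    # modify_inputs is replacement_before_exporting when either flag is set.
    ep = torch.export.export(model, (), kwargs=modify_inputs(inputs))
```

The remaining hunks, still in onnx_export_errors.py, thread the same flag through `torch_export_patches`: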
@@ -150,6 +156,7 @@ def torch_export_patches(
     patch_sympy: bool = True,
     patch_torch: bool = True,
     patch_transformers: bool = False,
+    patch_diffusers: bool = False,
     catch_constraints: bool = True,
     stop_if_static: int = 0,
     verbose: int = 0,

@@ -165,6 +172,7 @@ def torch_export_patches(
     :param patch_sympy: fix missing method ``name`` for IntegerConstant
     :param patch_torch: patches :epkg:`torch` with supported implementation
     :param patch_transformers: patches :epkg:`transformers` with supported implementation
+    :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
     :param catch_constraints: catch constraints related to dynamic shapes,
         as a result, some dynamic dimension may turn into static ones,
         the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``

@@ -174,8 +182,8 @@ def torch_export_patches(
         and show a stack trace indicating the exact location of the issue,
         ``if stop_if_static > 1``, more methods are replace to catch more
         issues
-    :param patch: if False, disable all patches
-        serialization
+    :param patch: if False, disable all patches but keeps the registration of
+        serialization functions if other patch functions are enabled
     :param custom_patches: to apply custom patches,
         every patched class must define static attributes
         ``_PATCHES_``, ``_PATCHED_CLASS_``

@@ -249,6 +257,7 @@ def torch_export_patches(
         patch_sympy=patch_sympy,
         patch_torch=patch_torch,
         patch_transformers=patch_transformers,
+        patch_diffusers=patch_diffusers,
         catch_constraints=catch_constraints,
         stop_if_static=stop_if_static,
         verbose=verbose,

@@ -261,7 +270,11 @@ def torch_export_patches(
         pass
     elif not patch:
         fct_callable = lambda x: x  # noqa: E731
-        done = register_cache_serialization(
+        done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )
         try:
             yield fct_callable
         finally:

@@ -281,7 +294,11 @@ def torch_export_patches(
     # caches
     ########
 
-    cache_done = register_cache_serialization(
+    cache_done = register_cache_serialization(
+        patch_transformers=patch_transformers,
+        patch_diffusers=patch_diffusers,
+        verbose=verbose,
+    )
 
     #############
     # patch sympy