onnx-diagnostic 0.7.15__py3-none-any.whl → 0.7.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -52
- onnx_diagnostic/export/dynamic_shapes.py +6 -6
- onnx_diagnostic/export/shape_helper.py +124 -6
- onnx_diagnostic/ext_test_case.py +5 -1
- onnx_diagnostic/helpers/cache_helper.py +67 -41
- onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
- onnx_diagnostic/helpers/helper.py +3 -0
- onnx_diagnostic/tasks/image_text_to_text.py +1 -1
- onnx_diagnostic/tasks/text_generation.py +1 -4
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +68 -10
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +84 -5
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +54 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
- onnx_diagnostic/torch_models/validate.py +39 -20
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +21 -19
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.15.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import textwrap
|
|
3
|
+
import torch
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
5
|
+
from ..helpers import flatten_object
|
|
6
|
+
from ..helpers.torch_helper import to_any
|
|
7
|
+
from .hghub.model_inputs import _preprocess_model_id
|
|
8
|
+
from .hghub import get_untrained_model_with_inputs
|
|
9
|
+
from .validate import filter_inputs, make_patch_kwargs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Source fragments assembled by ``code_sample`` into a standalone script.
# Every value is emitted verbatim into the generated code:
# * ``imports``: the imports every generated script starts with,
# * ``get_model_with_inputs``: a helper returning the model, either the
#   pretrained one or an untrained one, moved to the requested dtype and device.
CODE_SAMPLES = {
    "imports": "from typing import Any\nimport torch",
    "get_model_with_inputs": textwrap.dedent(
        """
        def get_model_with_inputs(
            model_id:str,
            subfolder: str | None = None,
            dtype: str | torch.dtype | None = None,
            device: str | torch.device | None = None,
            same_as_pretrained: bool = False,
            use_pretrained: bool = False,
            input_options: dict[str, Any] | None = None,
            model_options: dict[str, Any] | None = None,
        ) -> torch.nn.Module:
            if use_pretrained:
                import transformers
                assert same_as_pretrained, (
                    "same_as_pretrained must be True if use_pretrained is True"
                )
                # tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModel.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    subfolder=subfolder,
                    dtype=dtype,
                    device=device,
                )
                data = {"model": model}
                assert not input_options, f"Not implemented yet with input_options{input_options}"
                assert not model_options, f"Not implemented yet with model_options{model_options}"
            else:
                from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
                data = get_untrained_model_with_inputs(
                    model_id,
                    use_pretrained=use_pretrained,
                    same_as_pretrained=same_as_pretrained,
                    inputs_kwargs=input_options,
                    model_kwargs=model_options,
                    subfolder=subfolder,
                    add_second_input=False,
                )
            if dtype:
                data["model"] = data["model"].to(
                    getattr(torch, dtype) if isinstance(dtype, str) else dtype
                )
            if device:
                data["model"] = data["model"].to(device)
            return data["model"]
        """
    ),
}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
    """
    Creates a code to generate random inputs.

    :param inputs: dictionary mapping every input name to its value,
        a python scalar, a tensor or a ``DynamicCache``
    :return: code, a string such as ``dict(name=torch.rand(...), ...)``
    :raises ValueError: if a value has an unexpected type or
        a tensor has an unexpected dtype
    """
    codes = []
    for k, v in inputs.items():
        if isinstance(v, (int, bool, float)):
            code = f"{k}={v}"
        elif isinstance(v, torch.Tensor):
            shape = tuple(map(int, v.shape))
            if v.dtype in (torch.int32, torch.int64):
                # torch.randint requires a strictly positive exclusive upper bound
                # and Tensor.max fails on an empty tensor: both cases fall back to 1.
                high = max(int(v.max()), 1) if v.numel() else 1
                code = f"{k}=torch.randint({high}, size={shape}, dtype={v.dtype})"
            elif v.dtype in (torch.float32, torch.float16, torch.bfloat16):
                code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
            else:
                raise ValueError(f"Unexpected dtype = {v.dtype} for k={k!r}")
        elif v.__class__.__name__ == "DynamicCache":
            # The class is recognized by name, the flattened tensors are
            # split in two halves and paired element by element.
            flat = flatten_object(v)
            cc = [f"torch.rand({tuple(map(int, t.shape))}, dtype={t.dtype})" for t in flat]
            half = len(cc) // 2
            va = [f"({a},{b})" for a, b in zip(cc[:half], cc[half:])]
            va2 = ", ".join(va)
            code = f"{k}=make_dynamic_cache([{va2}])"
        else:
            raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
        codes.append(code)
    st = ", ".join(codes)
    return f"dict({st})"
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def make_export_code(
    exporter: str,
    optimization: Optional[str] = None,
    patch_kwargs: Optional[Dict[str, Any]] = None,
    stop_if_static: int = 0,
    dump_folder: Optional[str] = None,
    opset: Optional[int] = None,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    output_names: Optional[List[str]] = None,
    verbose: int = 0,
) -> Tuple[str, str]:
    """
    Creates the code exporting a model into ONNX.
    The generated code expects variables ``model`` and ``inputs`` to exist.

    :param exporter: ``custom`` or ``onnx-dynamo``
    :param optimization: optimization to apply, its meaning depends on the exporter
    :param patch_kwargs: if not empty, the export is wrapped into
        ``with torch_export_patches(...)`` called with these arguments,
        the dictionary is left unchanged
    :param stop_if_static: added to the patch arguments if not null
    :param dump_folder: if specified, the generated code creates this folder
        and saves the model into ``model.onnx`` in it
    :param opset: opset to request from the exporter
    :param dynamic_shapes: dynamic shapes, inserted in the code with ``str``
    :param output_names: output names given to the exporter
    :param verbose: verbosity, not used
    :return: two strings, the imports and the code
    :raises ValueError: if the exporter is not one of the supported values
    """
    args = [f"dynamic_shapes={dynamic_shapes}"]
    if output_names:
        args.append(f"output_names={output_names}")
    code = []
    imports = []
    # filename stays None when nothing has to be saved
    filename = None
    if dump_folder:
        code.append(f"os.makedirs({dump_folder!r})")
        imports.append("import os")
        filename = os.path.join(dump_folder, "model.onnx")
    if exporter == "custom":
        if opset:
            args.append(f"target_opset={opset}")
        if optimization:
            args.append(f"options=OptimizationOptions(patterns={optimization!r})")
        if filename is not None:
            # Without a dump folder, there is no file to write the model to.
            args.append(f"large_model=True, filename={filename!r}")
        sargs = ", ".join(args)
        imports.extend(
            [
                "from experimental_experiment.torch_interpreter import to_onnx",
                "from experimental_experiment.xbuilder import OptimizationOptions",
            ]
        )
        code.append(f"onx = to_onnx(model, inputs, {sargs})")
    elif exporter == "onnx-dynamo":
        if opset:
            args.append(f"opset_version={opset}")
        sargs = ", ".join(args)
        code.append(f"epo = torch.onnx.export(model, args=(), kwargs=inputs, {sargs})")
        if optimization:
            imports.append("import onnxscript")
            code.append("onnxscript.optimizer.optimize_ir(epo.model)")
            if "os_ort" in optimization:
                imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
                code.append("ort_fusions.optimize_for_ort(epo.model)")
        if filename is not None:
            code.append(f"epo.save({filename!r})")
    else:
        raise ValueError(f"Unexpected exporter {exporter!r}")
    if not patch_kwargs:
        return "\n".join(imports), "\n".join(code)

    imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
    # Works on a copy so that the dictionary of the caller is not modified.
    patch_kwargs = dict(patch_kwargs)
    if stop_if_static:
        patch_kwargs["stop_if_static"] = stop_if_static
    # repr keeps the generated call valid if a value is a string
    sargs = ", ".join(f"{k}={v!r}" for k, v in patch_kwargs.items())
    code = [f"with torch_export_patches({sargs}):", *["    " + line for line in code]]
    return "\n".join(imports), "\n".join(code)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def code_sample(
    model_id: str,
    task: Optional[str] = None,
    do_run: bool = False,
    exporter: Optional[str] = None,
    do_same: bool = False,
    verbose: int = 0,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    same_as_pretrained: bool = False,
    use_pretrained: bool = False,
    optimization: Optional[str] = None,
    quiet: bool = False,
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
    stop_if_static: int = 1,
    dump_folder: Optional[str] = None,
    drop_inputs: Optional[List[str]] = None,
    input_options: Optional[Dict[str, Any]] = None,
    model_options: Optional[Dict[str, Any]] = None,
    subfolder: Optional[str] = None,
    opset: Optional[int] = None,
    runtime: str = "onnxruntime",
    output_names: Optional[List[str]] = None,
) -> str:
    """
    This generates a code to export a model with the proper settings.

    :param model_id: model id to validate
    :param task: task used to generate the necessary inputs,
        can be left empty to use the default task for this model
        if it can be determined
    :param do_run: checks the model works with the defined inputs
    :param exporter: exports the model using this exporter,
        the code generator supports ``custom`` and ``onnx-dynamo``,
        no export code is generated if None
    :param do_same: checks the discrepancies of the exported model
    :param verbose: verbosity level
    :param dtype: uses this dtype to check the model
    :param device: do the verification on this device
    :param same_as_pretrained: use a model equivalent to the trained,
        this is not always possible
    :param use_pretrained: use the trained model, not the untrained one
    :param optimization: optimization to apply to the exported model,
        depends on the exporter
    :param quiet: if quiet, catches exception if any issue
    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
        if True before exporting
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
        a string can be used to specify only one of them
    :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param stop_if_static: stops if a dynamic dimension becomes static,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param dump_folder: dumps everything in a subfolder of this one
    :param drop_inputs: drops this list of inputs (given their names)
    :param input_options: additional options to define the dummy inputs
        used to export
    :param model_options: additional options when creating the model such as
        ``num_hidden_layers`` or ``attn_implementation``
    :param subfolder: version or subfolders to uses when retrieving a model id
    :param opset: onnx opset to use for the conversion
    :param runtime: onnx runtime to use to check about discrepancies,
        possible values ``onnxruntime``, ``torch``, ``orteval``,
        ``orteval10``, ``ref`` only if `do_run` is true
    :param output_names: output names the onnx exporter should use
    :return: a code

    Parameters ``do_run``, ``do_same``, ``quiet`` and ``runtime`` are accepted
    but the function does not use them.

    .. runpython::
        :showcode:

        from onnx_diagnostic.torch_models.code_sample import code_sample

        print(
            code_sample(
                "arnir0/Tiny-LLM",
                exporter="onnx-dynamo",
                optimization="ir",
                patch=True,
            )
        )
    """
    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
        model_id,
        subfolder,
        same_as_pretrained=same_as_pretrained,
        use_pretrained=use_pretrained,
    )
    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)

    iop = input_options or {}
    mop = model_options or {}
    data = get_untrained_model_with_inputs(
        model_id,
        verbose=verbose,
        task=task,
        use_pretrained=use_pretrained,
        same_as_pretrained=same_as_pretrained,
        inputs_kwargs=iop,
        model_kwargs=mop,
        subfolder=subfolder,
        add_second_input=False,
    )
    if drop_inputs:
        # Removes the unwanted inputs from every set of inputs
        # and keeps the dynamic shapes consistent with what remains.
        update = {}
        for k in data:
            if k.startswith("inputs"):
                update[k], ds = filter_inputs(
                    data[k],
                    drop_names=drop_inputs,
                    model=data["model"],
                    dynamic_shapes=data["dynamic_shapes"],
                )
                update["dynamic_shapes"] = ds
        data.update(update)

    # Converts every set of inputs to the requested dtype and device.
    update = {}
    for k in data:
        if k.startswith("inputs"):
            v = data[k]
            if dtype:
                update[k] = v = to_any(
                    v, getattr(torch, dtype) if isinstance(dtype, str) else dtype
                )
            if device:
                update[k] = v = to_any(v, device)
    if update:
        data.update(update)

    # Arguments of the call to get_model_with_inputs in the generated code.
    args = [f"{model_id!r}"]
    if subfolder:
        args.append(f"subfolder={subfolder!r}")
    if dtype:
        args.append(f"dtype={dtype!r}")
    if device:
        # repr(torch.device) is not valid code once pasted, its string is.
        args.append(f"device={str(device)!r}")
    if same_as_pretrained:
        args.append(f"same_as_pretrained={same_as_pretrained!r}")
    if use_pretrained:
        args.append(f"use_pretrained={use_pretrained!r}")
    if input_options:
        args.append(f"input_options={input_options!r}")
    if model_options:
        args.append(f"model_options={model_options!r}")
    model_args = ", ".join(args)
    if exporter is not None:
        imports, exporter_code = make_export_code(
            exporter=exporter,
            patch_kwargs=patch_kwargs,
            verbose=verbose,
            optimization=optimization,
            stop_if_static=stop_if_static,
            dump_folder=dump_folder,
            opset=opset,
            dynamic_shapes=data["dynamic_shapes"],
            output_names=output_names,
        )
    else:
        # Empty strings, not lists: the pieces are joined with str.join below.
        imports, exporter_code = "", ""
    input_code = make_code_for_inputs(data["inputs"])
    cache_import = (
        "from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache"
        if "dynamic_cache" in input_code
        else ""
    )

    pieces = [
        CODE_SAMPLES["imports"],
        imports,
        cache_import,
        CODE_SAMPLES["get_model_with_inputs"],
        textwrap.dedent(
            f"""
            model = get_model_with_inputs({model_args})
            """
        ),
        f"inputs = {input_code}",
        exporter_code,
    ]
    code = "\n".join(pieces)
    try:
        import black
    except ImportError:
        # No black formatting.
        return code

    return black.format_str(code, mode=black.Mode())
|
|
@@ -4865,3 +4865,41 @@ def _ccached_google_gemma_3_4b_it_like():
|
|
|
4865
4865
|
},
|
|
4866
4866
|
}
|
|
4867
4867
|
)
|
|
4868
|
+
|
|
4869
|
+
|
|
4870
|
+
def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
    "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
    # The string above is the docstring of the function, it holds the model id
    # and is kept verbatim (presumably used to look up the configuration,
    # to be confirmed against the code reading these functions).
    # NOTE(review): false, null, true are names coming from the module,
    # presumably aliases for False, None, True.
    config_kwargs = dict(
        architectures=["Gemma3ForCausalLM"],
        attention_bias=false,
        attention_dropout=0.0,
        attn_logit_softcapping=null,
        bos_token_id=2,
        cache_implementation="hybrid",
        eos_token_id=[1, 106],
        final_logit_softcapping=null,
        head_dim=8,
        hidden_activation="gelu_pytorch_tanh",
        hidden_size=16,
        initializer_range=0.02,
        intermediate_size=32,
        max_position_embeddings=32768,
        model_type="gemma3_text",
        num_attention_heads=2,
        num_hidden_layers=2,
        num_key_value_heads=1,
        pad_token_id=0,
        query_pre_attn_scalar=256,
        rms_norm_eps=1e-06,
        rope_local_base_freq=10000,
        rope_scaling=null,
        rope_theta=1000000,
        sliding_window=512,
        sliding_window_pattern=6,
        torch_dtype="float32",
        transformers_version="4.52.0.dev0",
        use_cache=true,
        vocab_size=262144,
    )
    return transformers.Gemma3TextConfig(**config_kwargs)
|
|
@@ -123,8 +123,8 @@ def _make_folder_name(
|
|
|
123
123
|
els = [model_id.replace("/", "_")]
|
|
124
124
|
if subfolder:
|
|
125
125
|
els.append(subfolder.replace("/", "_"))
|
|
126
|
-
if
|
|
127
|
-
els.append(task)
|
|
126
|
+
if task:
|
|
127
|
+
els.append(task)
|
|
128
128
|
if drop_inputs:
|
|
129
129
|
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
|
|
130
130
|
els.append(f"I-{ii.upper()}")
|
|
@@ -293,6 +293,33 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
293
293
|
return new_cfg
|
|
294
294
|
|
|
295
295
|
|
|
296
|
+
def make_patch_kwargs(
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
) -> Dict[str, Any]:
    """
    Builds the patch arguments out of the user request.

    :param patch: a boolean (all default patches or none),
        a comma separated list of patch names to enable,
        or a dictionary ``{name: bool}`` which is copied and completed
    :param rewrite: rewriting is requested, it needs patches to be enabled
    :return: a dictionary of arguments, key ``patch`` tells if patches are enabled
        whenever it could be determined
    """
    all_patches = {"patch_transformers": True, "patch_diffusers": True, "patch": True}
    if isinstance(patch, bool):
        kwargs = all_patches if patch else {"patch": False}
    elif isinstance(patch, str):
        # every name listed in the string is switched on
        kwargs = {"patch": True}
        for name in patch.split(","):
            kwargs[name] = True
    else:
        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
        kwargs = patch.copy()
        if "patch" in kwargs:
            # {"patch": True} alone stands for all the default patches
            if len(patch) == 1 and patch.get("patch", False):
                kwargs.update(all_patches)
        elif any(kwargs.values()):
            # at least one specific patch is enabled
            kwargs["patch"] = True

    if rewrite:
        assert kwargs.get("patch", False), (
            f"rewrite={rewrite}, patch={patch}, patch_kwargs={kwargs} "
            f"patch must be True to enable rewriting, "
            f"if --patch=0 was specified on the command line, rewrites are disabled."
        )
    return kwargs
|
|
321
|
+
|
|
322
|
+
|
|
296
323
|
def validate_model(
|
|
297
324
|
model_id: str,
|
|
298
325
|
task: Optional[str] = None,
|
|
@@ -420,25 +447,8 @@ def validate_model(
|
|
|
420
447
|
use_pretrained=use_pretrained,
|
|
421
448
|
)
|
|
422
449
|
time_preprocess_model_id = time.perf_counter() - main_validation_begin
|
|
423
|
-
|
|
424
|
-
if isinstance(patch, bool):
|
|
425
|
-
patch_kwargs = default_patch if patch else dict(patch=False)
|
|
426
|
-
elif isinstance(patch, str):
|
|
427
|
-
patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
|
|
428
|
-
else:
|
|
429
|
-
assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
|
|
430
|
-
patch_kwargs = patch.copy()
|
|
431
|
-
if "patch" not in patch_kwargs:
|
|
432
|
-
if any(patch_kwargs.values()):
|
|
433
|
-
patch_kwargs["patch"] = True
|
|
434
|
-
elif len(patch) == 1 and patch.get("patch", False):
|
|
435
|
-
patch_kwargs.update(default_patch)
|
|
450
|
+
patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
|
|
436
451
|
|
|
437
|
-
assert not rewrite or patch_kwargs.get("patch", False), (
|
|
438
|
-
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
|
|
439
|
-
f"patch must be True to enable rewriting, "
|
|
440
|
-
f"if --patch=0 was specified on the command line, rewrites are disabled."
|
|
441
|
-
)
|
|
442
452
|
summary = version_summary()
|
|
443
453
|
summary.update(
|
|
444
454
|
dict(
|
|
@@ -1890,6 +1900,7 @@ def call_torch_export_custom(
|
|
|
1890
1900
|
"custom-nostrict-all-noinline",
|
|
1891
1901
|
"custom-dec",
|
|
1892
1902
|
"custom-decall",
|
|
1903
|
+
"custom-fake",
|
|
1893
1904
|
}
|
|
1894
1905
|
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
|
|
1895
1906
|
assert "model" in data, f"model is missing from data: {sorted(data)}"
|
|
@@ -1898,6 +1909,14 @@ def call_torch_export_custom(
|
|
|
1898
1909
|
strict = "-strict" in exporter
|
|
1899
1910
|
args, kwargs = split_args_kwargs(data["inputs_export"])
|
|
1900
1911
|
ds = data.get("dynamic_shapes", None)
|
|
1912
|
+
if "-fake" in exporter:
|
|
1913
|
+
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
|
|
1914
|
+
|
|
1915
|
+
if verbose:
|
|
1916
|
+
print("[call_torch_export_custom] switching to FakeTensor")
|
|
1917
|
+
assert not args, f"Exporter {exporter!r} not implemented with fake tensors."
|
|
1918
|
+
kwargs = torch_deepcopy(kwargs)
|
|
1919
|
+
kwargs, _ = make_fake_with_dynamic_dimensions(kwargs, dynamic_shapes=ds)
|
|
1901
1920
|
opset = data.get("model_opset", None)
|
|
1902
1921
|
if opset:
|
|
1903
1922
|
summary["export_opset"] = opset
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-diagnostic
|
|
3
|
-
Version: 0.7.
|
|
3
|
+
Version: 0.7.16
|
|
4
4
|
Summary: Tools to help converting pytorch models into ONNX.
|
|
5
5
|
Home-page: https://github.com/sdpython/onnx-diagnostic
|
|
6
6
|
Author: Xavier Dupré
|
|
@@ -60,9 +60,9 @@ You need then to remove those which are not dynamic in your model.
|
|
|
60
60
|
|
|
61
61
|
.. code-block:: python
|
|
62
62
|
|
|
63
|
-
from onnx_diagnostic.export.shape_helper import
|
|
63
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
|
|
64
64
|
|
|
65
|
-
dynamic_shapes =
|
|
65
|
+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
|
|
66
66
|
|
|
67
67
|
It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...).
|
|
68
68
|
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
|
|
@@ -126,13 +126,13 @@ Snapshot of usefuls tools
|
|
|
126
126
|
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
127
127
|
# ...
|
|
128
128
|
|
|
129
|
-
**
|
|
129
|
+
**all_dynamic_shapes_from_inputs**
|
|
130
130
|
|
|
131
131
|
.. code-block:: python
|
|
132
132
|
|
|
133
|
-
from onnx_diagnostic.export.shape_helper import
|
|
133
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
|
|
134
134
|
|
|
135
|
-
dynamic_shapes =
|
|
135
|
+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
|
|
136
136
|
|
|
137
137
|
**torch_export_rewrite**
|
|
138
138
|
|
|
@@ -1,22 +1,23 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=eanCPlr_TSrwLKPl2uQn0CJzW015D7ORk2T08gXZg2o,174
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
|
-
onnx_diagnostic/_command_lines_parser.py,sha256=
|
|
3
|
+
onnx_diagnostic/_command_lines_parser.py,sha256=WjnVLqxMQ1sOo18dU4C-2nEZ40vkCy4MnlGBJAEuUOk,38173
|
|
4
4
|
onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
|
|
5
5
|
onnx_diagnostic/doc.py,sha256=t3RELgfooYnVMAi0JSpggWkQEgUsREz8NmRvn0TnLI8,2829
|
|
6
|
-
onnx_diagnostic/ext_test_case.py,sha256=
|
|
6
|
+
onnx_diagnostic/ext_test_case.py,sha256=f0bVfN0vqeeQy1mSb4YOThvJA5sCH3e0h9-u5x9hmrs,43620
|
|
7
7
|
onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
|
|
8
|
-
onnx_diagnostic/export/dynamic_shapes.py,sha256=
|
|
9
|
-
onnx_diagnostic/export/shape_helper.py,sha256=
|
|
8
|
+
onnx_diagnostic/export/dynamic_shapes.py,sha256=dvpIafK-DgalLMhVudmJkgn316UVerAyMMnvwQCNK5g,42004
|
|
9
|
+
onnx_diagnostic/export/shape_helper.py,sha256=Qr5pZeVuRahYza5hRAVDGkhshArysdJ1UhNJ603UI9w,12632
|
|
10
10
|
onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
|
|
11
11
|
onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
|
|
12
12
|
onnx_diagnostic/helpers/_log_helper.py,sha256=OTwQH0OIxs9B6nrSvR7MoxMimSw_8mU0mj133NvLk5o,16832
|
|
13
13
|
onnx_diagnostic/helpers/args_helper.py,sha256=SRWnqC7EENg09RZlA50B_PcdiIhdbgA4C3ACfzl5nMs,4419
|
|
14
14
|
onnx_diagnostic/helpers/bench_run.py,sha256=CGA6VMJZMH2gDhVueT9ypNm4PMcjGrrGFYp08nhWj9k,16539
|
|
15
|
-
onnx_diagnostic/helpers/cache_helper.py,sha256=
|
|
15
|
+
onnx_diagnostic/helpers/cache_helper.py,sha256=pvtx__UCBi367p6pD_-juIUgIyzEC_zfgmsYZdKzXSk,25899
|
|
16
16
|
onnx_diagnostic/helpers/config_helper.py,sha256=cWRETgFhZ7tayIZPnMqF8BF5AvTU64G2BMqyzgO7lzs,5670
|
|
17
17
|
onnx_diagnostic/helpers/doc_helper.py,sha256=pl5MZd3_FaE8BqQnqoBuSBxoNCFcd2OJd3eITUSku5c,5897
|
|
18
|
+
onnx_diagnostic/helpers/fake_tensor_helper.py,sha256=5KDX_bWbrenyYpBQUk7MNkpu28apn-qnshavmf22Uh8,5353
|
|
18
19
|
onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
|
|
19
|
-
onnx_diagnostic/helpers/helper.py,sha256=
|
|
20
|
+
onnx_diagnostic/helpers/helper.py,sha256=ut8upptmarQp2bVivnFmOTokMKslTRC6u0qooTNvSPA,63477
|
|
20
21
|
onnx_diagnostic/helpers/log_helper.py,sha256=xBKz5rj2-jEtN_tFKsOV4RpBGermrv7CWqG3KUm2psI,87335
|
|
21
22
|
onnx_diagnostic/helpers/memory_peak.py,sha256=OT6mz0muBbBZY0pjgW2_eCk_lOtFRo-5w4jFo2Z6Kok,6380
|
|
22
23
|
onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=Cgx1ojmV0S_JpZ_UqwsNxeULMMDvMInXslhkE34fwec,22051
|
|
@@ -77,7 +78,7 @@ onnx_diagnostic/tasks/automatic_speech_recognition.py,sha256=umZmjGW1gDUFkqvBJnQ
|
|
|
77
78
|
onnx_diagnostic/tasks/feature_extraction.py,sha256=Zh9p_Q8FqEO2_aqI0cCiq8OXuM3WUZbwItlLOmLnNl8,5537
|
|
78
79
|
onnx_diagnostic/tasks/fill_mask.py,sha256=5Gt6zlj0p6vuifox7Wmj-TpHXJvPS0CEH8evgdBHDNA,2640
|
|
79
80
|
onnx_diagnostic/tasks/image_classification.py,sha256=nLpBBB1Gkog3Fk6pu2waiHcuQr4ILPptc9FhQ-pn460,4682
|
|
80
|
-
onnx_diagnostic/tasks/image_text_to_text.py,sha256=
|
|
81
|
+
onnx_diagnostic/tasks/image_text_to_text.py,sha256=OdgyoT8FEZi-0HpWhFHWa8ySP9s_Tf3YsW46sbHPd-I,21576
|
|
81
82
|
onnx_diagnostic/tasks/image_to_video.py,sha256=SoF2cVIJr6P30Abp-FCuixFDh5RvTuNEOL36QthGY6U,3860
|
|
82
83
|
onnx_diagnostic/tasks/mask_generation.py,sha256=fjdD3rd-O-mFL0hQy3la3JXKth_0bH2HL7Eelq-3Dbs,5057
|
|
83
84
|
onnx_diagnostic/tasks/mixture_of_expert.py,sha256=al4tk1BrHidtRiHlAaiflWiJaAte0d5M8WcBioANG9k,2808
|
|
@@ -86,13 +87,13 @@ onnx_diagnostic/tasks/sentence_similarity.py,sha256=vPqNZgAnIvY0rKWPUTs0IlU3RFQD
|
|
|
86
87
|
onnx_diagnostic/tasks/summarization.py,sha256=8vB_JiRzDEacIvr8CYTuVQTH73xG_jNkndoS9RHJTSs,8292
|
|
87
88
|
onnx_diagnostic/tasks/text2text_generation.py,sha256=35eF_RlSeMdLTZPooLMAnszs-z0bkKZ34Iej3JgA96A,8602
|
|
88
89
|
onnx_diagnostic/tasks/text_classification.py,sha256=CGc72SpXFzTUyzAHEMPgyy_s187DaYGsRdrosxG80_Q,2711
|
|
89
|
-
onnx_diagnostic/tasks/text_generation.py,sha256=
|
|
90
|
+
onnx_diagnostic/tasks/text_generation.py,sha256=BVt8cBStUa-dwTdVwlIMUuxsz6NQTQU1a3i6CjyWl8A,14116
|
|
90
91
|
onnx_diagnostic/tasks/text_to_image.py,sha256=mOS3Ruosi3hzRMxXLDN7ZkAbi7NnQb7MWwQP_okGVHs,2962
|
|
91
92
|
onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=jJCMWuOqGv5ahCfjrcqxuYCJFhTgHV5KUf2yyv2yxYA,4624
|
|
92
93
|
onnx_diagnostic/tasks/data/__init__.py,sha256=uJoemrWgEjI6oA-tMX7r3__x-b3siPmkgqaY7bgIles,401
|
|
93
94
|
onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx,sha256=UbtvmWMqcZOKJ-I-HXWI1A6YR6QDaFS5u_yXm5C3ZBw,10299
|
|
94
95
|
onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
|
|
95
|
-
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=
|
|
96
|
+
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=WgoyNI9ueqDGNmkcIotCXLJFHKjSdoPywEym0m9X2KA,33206
|
|
96
97
|
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=K78uX5EHTuu0ah3mkZWNcGow4775GKH-EnDs3ZlIEhE,11778
|
|
97
98
|
onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
|
|
98
99
|
onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=2HQZKQV6TM5430RIvKiMPe4cfGvFdx1UnP1w76CeGE4,8110
|
|
@@ -101,18 +102,19 @@ onnx_diagnostic/torch_export_patches/patch_module_helper.py,sha256=2U0AdyZuU0W54
|
|
|
101
102
|
onnx_diagnostic/torch_export_patches/eval/__init__.py,sha256=YQoOGt9XQLWqnJ15NnT7ri_jDevfvpuQwEJo38E-VRU,25056
|
|
102
103
|
onnx_diagnostic/torch_export_patches/eval/model_cases.py,sha256=joDJV1YfrhYBR_6eXYvNO1jbiJM8Whb47NWZxo8SBwg,27172
|
|
103
104
|
onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
104
|
-
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=
|
|
105
|
-
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=
|
|
105
|
+
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=p6JOcDX1zXdYoTThqZP7oeKr67nJ1ZJJcvWobz9RjI4,44435
|
|
106
|
+
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=D6fE427n68j-YjvPyvU9obumDPgyIK7AWRkOukraEdM,83525
|
|
106
107
|
onnx_diagnostic/torch_export_patches/serialization/__init__.py,sha256=BHLdRPtNAtNPAS-bPKEj3-foGSPvwAbZXrHzGGPDLEw,1876
|
|
107
108
|
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py,sha256=drq3EH_yjcSuIWYsVeUWm8Cx6YCZFU6bP_1PLtPfY5I,945
|
|
108
109
|
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py,sha256=mcmZGekzQlLgE_o3SdKlRgCx4ewwyyAuNWZ9CaN_zrI,9317
|
|
109
110
|
onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
111
|
+
onnx_diagnostic/torch_models/code_sample.py,sha256=PWf7piGx7Eiv7BOTpL2bLUtWwaVcw7SMBvkSpEzZDPs,12883
|
|
110
112
|
onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
|
|
111
|
-
onnx_diagnostic/torch_models/validate.py,sha256=
|
|
113
|
+
onnx_diagnostic/torch_models/validate.py,sha256=ErpdXa8Gh9NAUsUsaf5JcvQVscBa3ZUkQ82PEZSawRc,81098
|
|
112
114
|
onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
|
|
113
115
|
onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=rFbiPNLET-KdBpnv-p0nKgwHX6d7C_Z0s9zZ86_92kQ,14307
|
|
114
116
|
onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=8V_pAgACPLPsLRYUododg7MSL6str-T3tBEGY4OaeYQ,8724
|
|
115
|
-
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=
|
|
117
|
+
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=VdLMPQ_ZpPg6cGMy00R2bGjSLWhvjkx8awY7kFD03I8,290141
|
|
116
118
|
onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=xIY_CWOp3m5-cJUvDLTZiH9GwiXi6xTYwONgFY4o45g,15593
|
|
117
119
|
onnx_diagnostic/torch_models/hghub/model_specific.py,sha256=j50Nu7wddJMoqmD4QzMbNdFDUUgUmSBKRzPDH55TlUQ,2498
|
|
118
120
|
onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -121,8 +123,8 @@ onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tm
|
|
|
121
123
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
122
124
|
onnx_diagnostic/torch_onnx/runtime_info.py,sha256=1g9F_Jf9AAgYQU4stbsrFXwQl-30mWlQrFbQ7val8Ps,9268
|
|
123
125
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=IoKLA5UwS6kY8g4OOf_bdQwCziIsQfBczZ3w8wo4wZM,16905
|
|
124
|
-
onnx_diagnostic-0.7.
|
|
125
|
-
onnx_diagnostic-0.7.
|
|
126
|
-
onnx_diagnostic-0.7.
|
|
127
|
-
onnx_diagnostic-0.7.
|
|
128
|
-
onnx_diagnostic-0.7.
|
|
126
|
+
onnx_diagnostic-0.7.16.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
127
|
+
onnx_diagnostic-0.7.16.dist-info/METADATA,sha256=k9xWoY46lqN9fhGpujHDnGFZ4WB3kpSc8XHtWZ63sNg,6735
|
|
128
|
+
onnx_diagnostic-0.7.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
129
|
+
onnx_diagnostic-0.7.16.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
130
|
+
onnx_diagnostic-0.7.16.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|