onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +156 -47
  3. onnx_diagnostic/export/dynamic_shapes.py +6 -6
  4. onnx_diagnostic/export/shape_helper.py +124 -6
  5. onnx_diagnostic/ext_test_case.py +5 -1
  6. onnx_diagnostic/helpers/cache_helper.py +68 -42
  7. onnx_diagnostic/helpers/config_helper.py +2 -1
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
  9. onnx_diagnostic/helpers/helper.py +3 -0
  10. onnx_diagnostic/helpers/rt_helper.py +3 -3
  11. onnx_diagnostic/tasks/image_text_to_text.py +7 -6
  12. onnx_diagnostic/tasks/text_generation.py +7 -4
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
  15. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
  17. onnx_diagnostic/torch_models/code_sample.py +343 -0
  18. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
  19. onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
  20. onnx_diagnostic/torch_models/validate.py +73 -29
  21. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
  22. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
  23. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
  24. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
  25. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_models/code_sample.py
@@ -0,0 +1,343 @@
+ import os
+ import textwrap
+ import torch
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ from ..helpers import flatten_object
+ from ..helpers.torch_helper import to_any
+ from .hghub.model_inputs import _preprocess_model_id
+ from .hghub import get_untrained_model_with_inputs
+ from .validate import filter_inputs, make_patch_kwargs
+
+
+ CODE_SAMPLES = {
+     "imports": "from typing import Any\nimport torch",
+     "get_model_with_inputs": textwrap.dedent(
+         """
+         def get_model_with_inputs(
+             model_id: str,
+             subfolder: str | None = None,
+             dtype: str | torch.dtype | None = None,
+             device: str | torch.device | None = None,
+             same_as_pretrained: bool = False,
+             use_pretrained: bool = False,
+             input_options: dict[str, Any] | None = None,
+             model_options: dict[str, Any] | None = None,
+         ) -> torch.nn.Module:
+             if use_pretrained:
+                 import transformers
+                 assert same_as_pretrained, (
+                     "same_as_pretrained must be True if use_pretrained is True"
+                 )
+                 # tokenizer = AutoTokenizer.from_pretrained(model_path)
+                 model = transformers.AutoModel.from_pretrained(
+                     model_id,
+                     trust_remote_code=True,
+                     subfolder=subfolder,
+                     dtype=dtype,
+                     device=device,
+                 )
+                 data = {"model": model}
+                 assert not input_options, f"Not implemented yet with input_options={input_options}"
+                 assert not model_options, f"Not implemented yet with model_options={model_options}"
+             else:
+                 from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+                 data = get_untrained_model_with_inputs(
+                     model_id,
+                     use_pretrained=use_pretrained,
+                     same_as_pretrained=same_as_pretrained,
+                     inputs_kwargs=input_options,
+                     model_kwargs=model_options,
+                     subfolder=subfolder,
+                     add_second_input=False,
+                 )
+             if dtype:
+                 data["model"] = data["model"].to(
+                     getattr(torch, dtype) if isinstance(dtype, str) else dtype
+                 )
+             if device:
+                 data["model"] = data["model"].to(device)
+             return data["model"]
+         """
+     ),
+ }
+
+
+ def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
+     """
+     Creates code that generates random inputs.
+
+     :param inputs: dictionary of inputs
+     :return: code
+     """
+     codes = []
+     for k, v in inputs.items():
+         if isinstance(v, (int, bool, float)):
+             code = f"{k}={v}"
+         elif isinstance(v, torch.Tensor):
+             shape = tuple(map(int, v.shape))
+             if v.dtype in (torch.int32, torch.int64):
+                 code = f"{k}=torch.randint({v.max()}, size={shape}, dtype={v.dtype})"
+             elif v.dtype in (torch.float32, torch.float16, torch.bfloat16):
+                 code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
+             else:
+                 raise ValueError(f"Unexpected dtype = {v.dtype} for k={k!r}")
+         elif v.__class__.__name__ == "DynamicCache":
+             obj = flatten_object(v)
+             cc = [f"torch.rand({tuple(map(int, _.shape))}, dtype={_.dtype})" for _ in obj]
+             va = [f"({a},{b})" for a, b in zip(cc[: len(cc) // 2], cc[len(cc) // 2 :])]
+             va2 = ", ".join(va)
+             code = f"{k}=make_dynamic_cache([{va2}])"
+         else:
+             raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
+         codes.append(code)
+     st = ", ".join(codes)
+     return f"dict({st})"
+
+
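
As an illustration (an editor's sketch, not part of the diff): for a small dictionary of tensors, make_code_for_inputs emits a string such as the following. The integer bound comes from v.max(), so it depends on the sampled values.

    import torch
    from onnx_diagnostic.torch_models.code_sample import make_code_for_inputs

    inputs = {
        "input_ids": torch.randint(100, size=(2, 8), dtype=torch.int64),
        "attention_mask": torch.rand((2, 8), dtype=torch.float32),
    }
    print(make_code_for_inputs(inputs))
    # prints something like (on one line):
    # dict(input_ids=torch.randint(99, size=(2, 8), dtype=torch.int64),
    #      attention_mask=torch.rand((2, 8), dtype=torch.float32))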
+ def make_export_code(
+     exporter: str,
+     optimization: Optional[str] = None,
+     patch_kwargs: Optional[Dict[str, Any]] = None,
+     stop_if_static: int = 0,
+     dump_folder: Optional[str] = None,
+     opset: Optional[int] = None,
+     dynamic_shapes: Optional[Dict[str, Any]] = None,
+     output_names: Optional[List[str]] = None,
+     verbose: int = 0,
+ ) -> Tuple[str, str]:
+     args = [f"dynamic_shapes={dynamic_shapes}"]
+     if output_names:
+         args.append(f"output_names={output_names}")
+     code = []
+     imports = []
+     filename = None
+     if dump_folder:
+         code.append(f"os.makedirs({dump_folder!r}, exist_ok=True)")
+         imports.append("import os")
+         filename = os.path.join(dump_folder, "model.onnx")
+     if exporter == "custom":
+         if opset:
+             args.append(f"target_opset={opset}")
+         if optimization:
+             args.append(f"options=OptimizationOptions(patterns={optimization!r})")
+         if filename:
+             args.append(f"large_model=True, filename={filename!r}")
+         sargs = ", ".join(args)
+         imports.extend(
+             [
+                 "from experimental_experiment.torch_interpreter import to_onnx",
+                 "from experimental_experiment.xbuilder import OptimizationOptions",
+             ]
+         )
+         code.extend([f"onx = to_onnx(model, inputs, {sargs})"])
+     elif exporter == "onnx-dynamo":
+         if opset:
+             args.append(f"opset_version={opset}")
+         sargs = ", ".join(args)
+         code.extend([f"epo = torch.onnx.export(model, args=(), kwargs=inputs, dynamo=True, {sargs})"])
+         if optimization:
+             imports.append("import onnxscript")
+             code.extend(["onnxscript.optimizer.optimize_ir(epo.model)"])
+             if "os_ort" in optimization:
+                 imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
+                 code.extend(["ort_fusions.optimize_for_ort(epo.model)"])
+         if dump_folder:
+             code.extend([f"epo.save({filename!r})"])
+     else:
+         raise ValueError(f"Unexpected exporter {exporter!r}")
+     if not patch_kwargs:
+         return "\n".join(imports), "\n".join(code)
+
+     imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
+     if stop_if_static:
+         patch_kwargs["stop_if_static"] = stop_if_static
+     sargs = ", ".join(f"{k}={v}" for k, v in patch_kwargs.items())
+     code = [f"with torch_export_patches({sargs}):", *["    " + _ for _ in code]]
+     return "\n".join(imports), "\n".join(code)
+
+
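
A sketch of what make_export_code returns on the onnx-dynamo path (output reconstructed from the code above, not captured from a run):

    from onnx_diagnostic.torch_models.code_sample import make_export_code

    imports, code = make_export_code(
        exporter="onnx-dynamo",
        patch_kwargs={"patch_transformers": True, "patch": True},
        dynamic_shapes={"input_ids": {0: "batch", 1: "seq"}},
    )
    print(imports)
    # from onnx_diagnostic.torch_export_patches import torch_export_patches
    print(code)
    # with torch_export_patches(patch_transformers=True, patch=True):
    #     epo = torch.onnx.export(model, args=(), kwargs=inputs, dynamo=True,
    #         dynamic_shapes={'input_ids': {0: 'batch', 1: 'seq'}})  # one line in reality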
+ def code_sample(
+     model_id: str,
+     task: Optional[str] = None,
+     do_run: bool = False,
+     exporter: Optional[str] = None,
+     do_same: bool = False,
+     verbose: int = 0,
+     dtype: Optional[Union[str, torch.dtype]] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     same_as_pretrained: bool = False,
+     use_pretrained: bool = False,
+     optimization: Optional[str] = None,
+     quiet: bool = False,
+     patch: Union[bool, str, Dict[str, bool]] = False,
+     rewrite: bool = False,
+     stop_if_static: int = 1,
+     dump_folder: Optional[str] = None,
+     drop_inputs: Optional[List[str]] = None,
+     input_options: Optional[Dict[str, Any]] = None,
+     model_options: Optional[Dict[str, Any]] = None,
+     subfolder: Optional[str] = None,
+     opset: Optional[int] = None,
+     runtime: str = "onnxruntime",
+     output_names: Optional[List[str]] = None,
+ ) -> str:
+     """
+     Generates the code needed to export a model with the proper settings.
+
+     :param model_id: model id to validate
+     :param task: task used to generate the necessary inputs,
+         can be left empty to use the default task for this model
+         if it can be determined
+     :param do_run: checks that the model works with the defined inputs
+     :param exporter: exports the model using this exporter,
+         available list: ``export-strict``, ``export-nostrict``, ...
+         see below
+     :param do_same: checks the discrepancies of the exported model
+     :param verbose: verbosity level
+     :param dtype: uses this dtype to check the model
+     :param device: does the verification on this device
+     :param same_as_pretrained: uses a model equivalent to the trained one,
+         this is not always possible
+     :param use_pretrained: uses the trained model, not the untrained one
+     :param optimization: optimization to apply to the exported model,
+         depends on the exporter
+     :param quiet: if quiet, catches exceptions instead of raising them
+     :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
+         before exporting if True,
+         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
+         a string can be used to specify only one of them
+     :param rewrite: applies known rewritings (``patch_transformers=True``) before exporting,
+         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+     :param stop_if_static: stops if a dynamic dimension becomes static,
+         see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
+     :param dump_folder: dumps everything in a subfolder of this one
+     :param drop_inputs: drops this list of inputs (given their names)
+     :param input_options: additional options to define the dummy inputs
+         used to export
+     :param model_options: additional options when creating the model such as
+         ``num_hidden_layers`` or ``attn_implementation``
+     :param subfolder: version or subfolder to use when retrieving a model id
+     :param opset: onnx opset to use for the conversion
+     :param runtime: onnx runtime to use to check for discrepancies,
+         possible values ``onnxruntime``, ``torch``, ``orteval``,
+         ``orteval10``, ``ref``, only used if ``do_run`` is true
+     :param output_names: output names the onnx exporter should use
+     :return: the generated code
+
+     .. runpython::
+         :showcode:
+
+         from onnx_diagnostic.torch_models.code_sample import code_sample
+
+         print(
+             code_sample(
+                 "arnir0/Tiny-LLM",
+                 exporter="onnx-dynamo",
+                 optimization="ir",
+                 patch=True,
+             )
+         )
+     """
+     model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+         model_id,
+         subfolder,
+         same_as_pretrained=same_as_pretrained,
+         use_pretrained=use_pretrained,
+     )
+     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
+
+     iop = input_options or {}
+     mop = model_options or {}
+     data = get_untrained_model_with_inputs(
+         model_id,
+         verbose=verbose,
+         task=task,
+         use_pretrained=use_pretrained,
+         same_as_pretrained=same_as_pretrained,
+         inputs_kwargs=iop,
+         model_kwargs=mop,
+         subfolder=subfolder,
+         add_second_input=False,
+     )
+     if drop_inputs:
+         update = {}
+         for k in data:
+             if k.startswith("inputs"):
+                 update[k], ds = filter_inputs(
+                     data[k],
+                     drop_names=drop_inputs,
+                     model=data["model"],
+                     dynamic_shapes=data["dynamic_shapes"],
+                 )
+                 update["dynamic_shapes"] = ds
+         data.update(update)
+
+     update = {}
+     for k in data:
+         if k.startswith("inputs"):
+             v = data[k]
+             if dtype:
+                 update[k] = v = to_any(
+                     v, getattr(torch, dtype) if isinstance(dtype, str) else dtype
+                 )
+             if device:
+                 update[k] = v = to_any(v, device)
+     if update:
+         data.update(update)
+
+     args = [f"{model_id!r}"]
+     if subfolder:
+         args.append(f"subfolder={subfolder!r}")
+     if dtype:
+         args.append(f"dtype={dtype!r}")
+     if device:
+         args.append(f"device={device!r}")
+     if same_as_pretrained:
+         args.append(f"same_as_pretrained={same_as_pretrained!r}")
+     if use_pretrained:
+         args.append(f"use_pretrained={use_pretrained!r}")
+     if input_options:
+         args.append(f"input_options={input_options!r}")
+     if model_options:
+         args.append(f"model_options={model_options!r}")
+     model_args = ", ".join(args)
+     imports, exporter_code = (
+         make_export_code(
+             exporter=exporter,
+             patch_kwargs=patch_kwargs,
+             verbose=verbose,
+             optimization=optimization,
+             stop_if_static=stop_if_static,
+             dump_folder=dump_folder,
+             opset=opset,
+             dynamic_shapes=data["dynamic_shapes"],
+         )
+         if exporter is not None
+         else ("", "")
+     )
+     input_code = make_code_for_inputs(data["inputs"])
+     cache_import = (
+         "from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache"
+         if "dynamic_cache" in input_code
+         else ""
+     )
+
+     pieces = [
+         CODE_SAMPLES["imports"],
+         imports,
+         cache_import,
+         CODE_SAMPLES["get_model_with_inputs"],
+         textwrap.dedent(
+             f"""
+             model = get_model_with_inputs({model_args})
+             """
+         ),
+         f"inputs = {input_code}",
+         exporter_code,
+     ]
+     code = "\n".join(pieces)
+     try:
+         import black
+     except ImportError:
+         # No black formatting.
+         return code
+
+     return black.format_str(code, mode=black.Mode())
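
A minimal usage sketch (the output file name is illustrative): code_sample returns a self-contained script that can be saved and run separately.

    from onnx_diagnostic.torch_models.code_sample import code_sample

    code = code_sample("arnir0/Tiny-LLM", exporter="custom", patch=True, opset=18)
    with open("export_tiny_llm.py", "w") as f:
        f.write(code)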

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4865,3 +4865,41 @@ def _ccached_google_gemma_3_4b_it_like():
              },
          }
      )
+
+
+ def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
+     "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+     return transformers.Gemma3TextConfig(
+         **{
+             "architectures": ["Gemma3ForCausalLM"],
+             "attention_bias": False,
+             "attention_dropout": 0.0,
+             "attn_logit_softcapping": None,
+             "bos_token_id": 2,
+             "cache_implementation": "hybrid",
+             "eos_token_id": [1, 106],
+             "final_logit_softcapping": None,
+             "head_dim": 8,
+             "hidden_activation": "gelu_pytorch_tanh",
+             "hidden_size": 16,
+             "initializer_range": 0.02,
+             "intermediate_size": 32,
+             "max_position_embeddings": 32768,
+             "model_type": "gemma3_text",
+             "num_attention_heads": 2,
+             "num_hidden_layers": 2,
+             "num_key_value_heads": 1,
+             "pad_token_id": 0,
+             "query_pre_attn_scalar": 256,
+             "rms_norm_eps": 1e-06,
+             "rope_local_base_freq": 10000,
+             "rope_scaling": None,
+             "rope_theta": 1000000,
+             "sliding_window": 512,
+             "sliding_window_pattern": 6,
+             "torch_dtype": "float32",
+             "transformers_version": "4.52.0.dev0",
+             "use_cache": True,
+             "vocab_size": 262144,
+         }
+     )
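
The cached configuration can be instantiated into a tiny untrained model; a sketch assuming a transformers version that ships Gemma3ForCausalLM:

    import transformers

    cfg = _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm()
    model = transformers.Gemma3ForCausalLM(cfg)  # untrained: 2 layers, hidden_size=16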

onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -95,6 +95,8 @@ def get_untrained_model_with_inputs(
          print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
          print("-- configuration:", pprint.pformat(data['configuration']))
      """
+     if task == "":
+         task = None
      assert not use_preinstalled or not use_only_preinstalled, (
          f"model_id={model_id!r}, preinstalled model is only available "
          f"if use_only_preinstalled is False."
@@ -120,14 +122,16 @@ def get_untrained_model_with_inputs(
          **(model_kwargs or {}),
      )

-     model, task, mkwargs, diff_config = None, None, {}, None
+     model, task_, mkwargs, diff_config = None, None, {}, None
      if use_pretrained and same_as_pretrained:
          if model_id in HANDLED_MODELS:
-             model, task, config = load_specific_model(model_id, verbose=verbose)
+             model, task_, config = load_specific_model(model_id, verbose=verbose)

+     if task is None:
+         task = task_
      if model is None:
          arch = architecture_from_config(config)
-         if arch is None:
+         if task is None and arch is None:
              task = task_from_id(model_id, subfolder=subfolder)
          assert task is not None or arch is not None, (
              f"Unable to determine the architecture for model {model_id!r}, "

onnx_diagnostic/torch_models/validate.py
@@ -4,7 +4,7 @@ import inspect
  import os
  import pprint
  import sys
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
  import time
  import numpy as np
  import onnx
@@ -117,11 +117,21 @@ def _make_folder_name(
      drop_inputs: Optional[List[str]] = None,
      same_as_pretrained: bool = False,
      use_pretrained: bool = False,
+     task: Optional[str] = None,
  ) -> str:
      "Creates a unique folder name based on the given options."
      els = [model_id.replace("/", "_")]
      if subfolder:
          els.append(subfolder.replace("/", "_"))
+     if task:
+         els.append(task)
+     if drop_inputs:
+         ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
+         els.append(f"I-{ii.upper()}")
+     if use_pretrained:
+         els.append("TRAINED")
+     elif same_as_pretrained:
+         els.append("SAMESIZE")
      if exporter:
          els.append(exporter)
      if optimization:
@@ -142,14 +152,7 @@ def _make_folder_name(
          els.append(sdev)
      if opset is not None:
          els.append(f"op{opset}")
-     if drop_inputs:
-         ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
-         els.append(f"I-{ii.upper()}")
-     if use_pretrained:
-         els.append("TRAINED")
-     elif same_as_pretrained:
-         els.append("SAMESIZE")
-     return "-".join(els)
+     return "/".join([e for e in els if e])


  def version_summary() -> Dict[str, Union[int, float, str]]:
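
The folder name is now a nested path rather than one dash-joined string, with task, dropped inputs, and the pretrained flags moved to the front. An illustrative result for drop_inputs=["position_ids"] (the dtype and device segments come from code outside this hunk):

    arnir0_Tiny-LLM/text-generation/I-PS/TRAINED/onnx-dynamo/ir/.../op18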
@@ -290,6 +293,33 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
      return new_cfg


+ def make_patch_kwargs(
+     patch: Union[bool, str, Dict[str, bool]] = False,
+     rewrite: bool = False,
+ ) -> Dict[str, Any]:
+     """Creates patch arguments."""
+     default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
+     if isinstance(patch, bool):
+         patch_kwargs = default_patch if patch else dict(patch=False)
+     elif isinstance(patch, str):
+         patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
+     else:
+         assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
+         patch_kwargs = patch.copy()
+         if "patch" not in patch_kwargs:
+             if any(patch_kwargs.values()):
+                 patch_kwargs["patch"] = True
+         elif len(patch) == 1 and patch.get("patch", False):
+             patch_kwargs.update(default_patch)
+
+     assert not rewrite or patch_kwargs.get("patch", False), (
+         f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs}, "
+         f"patch must be True to enable rewriting; "
+         f"if --patch=0 was specified on the command line, rewrites are disabled."
+     )
+     return patch_kwargs
+
+
  def validate_model(
      model_id: str,
      task: Optional[str] = None,
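
The refactored helper behaves like the inlined logic it replaces in validate_model below; the mapping can be read off the code above:

    make_patch_kwargs(patch=True)
    # -> {'patch_transformers': True, 'patch_diffusers': True, 'patch': True}
    make_patch_kwargs(patch="patch_transformers")
    # -> {'patch': True, 'patch_transformers': True}
    make_patch_kwargs(patch=False)
    # -> {'patch': False}
    make_patch_kwargs(patch=False, rewrite=True)
    # -> AssertionError: rewriting requires patch=True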
@@ -319,6 +349,7 @@ def validate_model(
      inputs2: int = 1,
      output_names: Optional[List[str]] = None,
      ort_logs: bool = False,
+     quiet_input_sets: Optional[Set[str]] = None,
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
      """
      Validates a model.
@@ -373,6 +404,8 @@ def validate_model(
          or an empty cache for example
      :param output_names: output names the onnx exporter should use
      :param ort_logs: increases onnxruntime verbosity when creating the session
+     :param quiet_input_sets: avoids raising an exception if the input set belongs
+         to that set, even if quiet is False
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces

@@ -414,25 +447,8 @@ def validate_model(
          use_pretrained=use_pretrained,
      )
      time_preprocess_model_id = time.perf_counter() - main_validation_begin
-     default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
-     if isinstance(patch, bool):
-         patch_kwargs = default_patch if patch else dict(patch=False)
-     elif isinstance(patch, str):
-         patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
-     else:
-         assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
-         patch_kwargs = patch.copy()
-         if "patch" not in patch_kwargs:
-             if any(patch_kwargs.values()):
-                 patch_kwargs["patch"] = True
-         elif len(patch) == 1 and patch.get("patch", False):
-             patch_kwargs.update(default_patch)
+     patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)

-     assert not rewrite or patch_kwargs.get("patch", False), (
-         f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
-         f"patch must be True to enable rewriting, "
-         f"if --patch=0 was specified on the command line, rewrites are disabled."
-     )
      summary = version_summary()
      summary.update(
          dict(
@@ -473,6 +489,7 @@ def validate_model(
          drop_inputs=drop_inputs,
          use_pretrained=use_pretrained,
          same_as_pretrained=same_as_pretrained,
+         task=task,
      )
      dump_folder = os.path.join(dump_folder, folder_name)
      if not os.path.exists(dump_folder):
@@ -487,6 +504,8 @@ def validate_model(
          print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
      else:
          print(f"[validate_model] validate model id {model_id!r}")
+     if task:
+         print(f"[validate_model] with task {task!r}")
      print(f"[validate_model] patch={patch!r}")
      if model_options:
          print(f"[validate_model] model_options={model_options!r}")
@@ -762,6 +781,10 @@ def validate_model(
          ep = data["exported_program"]
          if verbose:
              print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
+         assert isinstance(
+             folder_name, str
+         ), f"folder_name={folder_name!r} should be a string"
+         folder_name = folder_name.replace("/", "-")
          with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
              f.write(str(ep))
          torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
@@ -770,6 +793,10 @@ def validate_model(
          if verbose:
              print("[validate_model] done (dump ep)")
      if "onnx_program" in data:
+         assert isinstance(
+             folder_name, str
+         ), f"folder_name={folder_name!r} should be a string"
+         folder_name = folder_name.replace("/", "-")
          epo = data["onnx_program"]
          if verbose:
              print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
@@ -842,6 +869,7 @@ def validate_model(
          warmup=warmup,
          second_input_keys=second_input_keys,
          ort_logs=ort_logs,
+         quiet_input_sets=quiet_input_sets,
      )
      summary.update(summary_valid)
      summary["time_total_validation_onnx"] = time.perf_counter() - validation_begin
@@ -904,6 +932,7 @@ def validate_model(
          repeat=repeat,
          warmup=warmup,
          second_input_keys=second_input_keys,
+         quiet_input_sets=quiet_input_sets,
      )
      summary.update(summary_valid)

@@ -1289,6 +1318,7 @@ def validate_onnx_model(
      warmup: int = 0,
      second_input_keys: Optional[List[str]] = None,
      ort_logs: bool = False,
+     quiet_input_sets: Optional[Set[str]] = None,
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
      """
      Verifies that an onnx model produces the same
@@ -1308,6 +1338,7 @@ def validate_onnx_model(
          to make sure the exported model supports dynamism, the value is
          used as an increment added to the first set of inputs (added to dimensions)
      :param ort_logs: triggers the logs for onnxruntime
+     :param quiet_input_sets: avoids raising an exception for these sets of inputs
      :return: two dictionaries, one with some metrics,
          another one with whatever the function produces
      """
@@ -1431,6 +1462,8 @@ def validate_onnx_model(
      keys = [("inputs", "run_expected", "")]
      if second_input_keys:
          keys.extend([(k, f"run_expected2{k[6:]}", f"2{k[6:]}") for k in second_input_keys])
+     if verbose:
+         print(f"[validate_onnx_model] -- keys={keys}")
      for k_input, k_expected, suffix in keys:
          # make_feeds
          assert k_input in data, f"Unable to find {k_input!r} in {sorted(data)}"
@@ -1455,10 +1488,12 @@ def validate_onnx_model(

          # run ort
          if verbose:
-             print("[validate_onnx_model] run session...")
+             print(f"[validate_onnx_model] run session on inputs 'inputs{suffix}'...")
+             if quiet_input_sets and f"inputs{suffix}" in quiet_input_sets:
+                 print(f"[validate_onnx_model] quiet_input_sets={quiet_input_sets}")

          got = _quiet_or_not_quiet(
-             quiet,
+             quiet or (quiet_input_sets is not None and f"inputs{suffix}" in quiet_input_sets),
              _mk(f"run_onnx_ort{suffix}"),
              summary,
              data,
@@ -1865,6 +1900,7 @@ def call_torch_export_custom(
          "custom-nostrict-all-noinline",
          "custom-dec",
          "custom-decall",
+         "custom-fake",
      }
      assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
      assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1873,6 +1909,14 @@ def call_torch_export_custom(
      strict = "-strict" in exporter
      args, kwargs = split_args_kwargs(data["inputs_export"])
      ds = data.get("dynamic_shapes", None)
+     if "-fake" in exporter:
+         from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
+
+         if verbose:
+             print("[call_torch_export_custom] switching to FakeTensor")
+         assert not args, f"Exporter {exporter!r} not implemented with fake tensors."
+         kwargs = torch_deepcopy(kwargs)
+         kwargs, _ = make_fake_with_dynamic_dimensions(kwargs, dynamic_shapes=ds)
      opset = data.get("model_opset", None)
      if opset:
          summary["export_opset"] = opset
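
For reference, an editor's sketch of the new fake-tensor path, assuming make_fake_with_dynamic_dimensions mirrors the call above (it returns the converted inputs plus extra information as a pair):

    import torch
    from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

    kwargs = {"input_ids": torch.randint(100, size=(2, 8), dtype=torch.int64)}
    ds = {"input_ids": {0: "batch", 1: "seq"}}
    fake_kwargs, _ = make_fake_with_dynamic_dimensions(kwargs, dynamic_shapes=ds)
    # fake_kwargs holds FakeTensors; the export then runs without real data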

{onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-diagnostic
- Version: 0.7.14
+ Version: 0.7.16
  Summary: Tools to help converting pytorch models into ONNX.
  Home-page: https://github.com/sdpython/onnx-diagnostic
  Author: Xavier Dupré
@@ -60,9 +60,9 @@ You then need to remove the ones which are not dynamic in your model.

  .. code-block:: python

-     from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+     from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs

-     dynamic_shapes = all_dynamic_shape_from_inputs(cache)
+     dynamic_shapes = all_dynamic_shapes_from_inputs(cache)

  It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...).
  See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
@@ -126,13 +126,13 @@ Snapshot of useful tools

      ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
      # ...

- **all_dynamic_shape_from_inputs**
+ **all_dynamic_shapes_from_inputs**

  .. code-block:: python

-     from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+     from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs

-     dynamic_shapes = all_dynamic_shape_from_inputs(cache)
+     dynamic_shapes = all_dynamic_shapes_from_inputs(cache)

  **torch_export_rewrite**
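
A sketch combining the two snippets above (shapes are illustrative; make_dynamic_cache takes a list of key/value pairs, as in the generated code earlier in this diff):

    import torch
    from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.rand(2, 4, 3, 8), torch.rand(2, 4, 3, 8))])
    dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
    # every dimension is reported dynamic; remove the ones that are static in your model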