onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import textwrap
|
|
3
|
+
import torch
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
5
|
+
from ..helpers import flatten_object
|
|
6
|
+
from ..helpers.torch_helper import to_any
|
|
7
|
+
from .hghub.model_inputs import _preprocess_model_id
|
|
8
|
+
from .hghub import get_untrained_model_with_inputs
|
|
9
|
+
from .validate import filter_inputs, make_patch_kwargs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
CODE_SAMPLES = {
    # Import lines placed at the top of every generated snippet.
    "imports": "from typing import Any\nimport torch",
    # Source of a standalone helper embedded verbatim in the generated snippet.
    # It rebuilds the model (pretrained or untrained) and returns the model only.
    # ``textwrap.dedent`` removes the common indentation so the ``def`` starts
    # at column 0 in the generated code.
    "get_model_with_inputs": textwrap.dedent(
        """
        def get_model_with_inputs(
            model_id: str,
            subfolder: str | None = None,
            dtype: str | torch.dtype | None = None,
            device: str | torch.device | None = None,
            same_as_pretrained: bool = False,
            use_pretrained: bool = False,
            input_options: dict[str, Any] | None = None,
            model_options: dict[str, Any] | None = None,
        ) -> torch.nn.Module:
            if use_pretrained:
                import transformers
                assert same_as_pretrained, (
                    "same_as_pretrained must be True if use_pretrained is True"
                )
                # tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = transformers.AutoModel.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    subfolder=subfolder,
                    dtype=dtype,
                    device=device,
                )
                data = {"model": model}
                assert not input_options, (
                    f"Not implemented yet with input_options={input_options}"
                )
                assert not model_options, (
                    f"Not implemented yet with model_options={model_options}"
                )
            else:
                from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
                data = get_untrained_model_with_inputs(
                    model_id,
                    use_pretrained=use_pretrained,
                    same_as_pretrained=same_as_pretrained,
                    inputs_kwargs=input_options,
                    model_kwargs=model_options,
                    subfolder=subfolder,
                    add_second_input=False,
                )
                if dtype:
                    data["model"] = data["model"].to(
                        getattr(torch, dtype) if isinstance(dtype, str) else dtype
                    )
                if device:
                    data["model"] = data["model"].to(device)
            return data["model"]
        """
    ),
}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def make_code_for_inputs(inputs: Dict[str, torch.Tensor]) -> str:
    """
    Creates a code to generate random inputs.

    Python scalars are written as they are, every tensor is replaced by a
    random tensor with the same shape and dtype, and a ``DynamicCache`` is
    rebuilt with ``make_dynamic_cache`` from random tensors.

    :param inputs: dictionary mapping an input name to a python scalar
        (int, bool, float), a :class:`torch.Tensor` or a ``DynamicCache``
    :return: code, an expression ``dict(name=..., ...)``
    :raises ValueError: if a value type or a tensor dtype is not supported
    """
    int_dtypes = (torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8)
    float_dtypes = (torch.float32, torch.float16, torch.bfloat16, torch.float64)
    codes = []
    for k, v in inputs.items():
        if isinstance(v, (int, bool, float)):
            code = f"{k}={v}"
        elif isinstance(v, torch.Tensor):
            shape = tuple(map(int, v.shape))
            if v.dtype in int_dtypes:
                # torch.randint excludes its upper bound: max + 1 keeps the
                # observed range [0, max] reachable, and the bound must stay >= 1
                # (all-zero or empty tensors) for randint to be valid.
                # int(...) avoids rendering the bound as "tensor(N)".
                high = max(int(v.max()) + 1, 1) if v.numel() else 1
                code = f"{k}=torch.randint({high}, size={shape}, dtype={v.dtype})"
            elif v.dtype in float_dtypes:
                code = f"{k}=torch.rand({shape}, dtype={v.dtype})"
            else:
                raise ValueError(f"Unexpected dtype = {v.dtype} for k={k!r}")
        elif v.__class__.__name__ == "DynamicCache":
            # NOTE(review): assumes flatten_object returns all key tensors
            # followed by all value tensors, so both halves pair up per layer.
            tensors = flatten_object(v)
            rands = [f"torch.rand({tuple(map(int, t.shape))}, dtype={t.dtype})" for t in tensors]
            half = len(rands) // 2
            pairs = ", ".join(f"({a},{b})" for a, b in zip(rands[:half], rands[half:]))
            code = f"{k}=make_dynamic_cache([{pairs}])"
        else:
            raise ValueError(f"Unexpected type {type(v)} for k={k!r}")
        codes.append(code)
    st = ", ".join(codes)
    return f"dict({st})"
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def make_export_code(
    exporter: str,
    optimization: Optional[str] = None,
    patch_kwargs: Optional[Dict[str, Any]] = None,
    stop_if_static: int = 0,
    dump_folder: Optional[str] = None,
    opset: Optional[int] = None,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    output_names: Optional[List[str]] = None,
    verbose: int = 0,
) -> Tuple[str, str]:
    """
    Generates the code exporting ``model`` called with ``inputs`` into ONNX.

    :param exporter: ``"custom"`` (``to_onnx`` from experimental_experiment)
        or ``"onnx-dynamo"`` (:func:`torch.onnx.export`)
    :param optimization: optimization patterns (custom) or optimization names
        (onnx-dynamo, ``os_ort`` adds onnxruntime fusions)
    :param patch_kwargs: arguments for ``torch_export_patches``, the export
        code is wrapped in that context manager if not empty,
        this dictionary is not modified
    :param stop_if_static: added to the patch arguments if not zero
        (only used when *patch_kwargs* is not empty)
    :param dump_folder: the generated code creates this folder and saves
        ``model.onnx`` in it
    :param opset: target opset
    :param dynamic_shapes: dynamic shapes, rendered with ``str``
    :param output_names: output names given to the exporter
    :param verbose: verbosity (unused)
    :return: a tuple ``(imports, code)``, both newline-separated lines
    :raises ValueError: if the exporter is unknown
    """
    args = [f"dynamic_shapes={dynamic_shapes}"]
    if output_names:
        args.append(f"output_names={output_names}")
    code: List[str] = []
    imports: List[str] = []
    filename = None
    if dump_folder:
        # exist_ok=True: the generated code must not fail if the folder exists.
        code.append(f"os.makedirs({dump_folder!r}, exist_ok=True)")
        imports.append("import os")
        filename = os.path.join(dump_folder, "model.onnx")

    if exporter == "custom":
        if opset:
            args.append(f"target_opset={opset}")
        if optimization:
            args.append(f"options=OptimizationOptions(patterns={optimization!r})")
        args.append("large_model=True")
        # Without a dump folder, the model stays in memory (no filename).
        if filename:
            args.append(f"filename={filename!r}")
        sargs = ", ".join(args)
        imports.extend(
            [
                "from experimental_experiment.torch_interpreter import to_onnx",
                "from experimental_experiment.xbuilder import OptimizationOptions",
            ]
        )
        code.append(f"onx = to_onnx(model, inputs, {sargs})")
    elif exporter == "onnx-dynamo":
        if opset:
            args.append(f"opset_version={opset}")
        sargs = ", ".join(args)
        # dynamo=True returns an ONNXProgram (epo.model, epo.save) and is
        # required for dynamic_shapes with torch versions defaulting to dynamo=False.
        code.append(
            f"epo = torch.onnx.export(model, args=(), kwargs=inputs, dynamo=True, {sargs})"
        )
        if optimization:
            imports.append("import onnxscript")
            code.append("onnxscript.optimizer.optimize_ir(epo.model)")
            if "os_ort" in optimization:
                imports.append("import onnxscript.rewriter.ort_fusions as ort_fusions")
                code.append("ort_fusions.optimize_for_ort(epo.model)")
        if filename:
            code.append(f"epo.save({filename!r})")
    else:
        raise ValueError(f"Unexpected exporter {exporter!r}")

    if not patch_kwargs:
        return "\n".join(imports), "\n".join(code)

    imports.append("from onnx_diagnostic.torch_export_patches import torch_export_patches")
    # Copy to avoid modifying the caller's dictionary.
    patch_kwargs = dict(patch_kwargs)
    if stop_if_static:
        patch_kwargs["stop_if_static"] = stop_if_static
    # repr keeps string values quoted in the generated code.
    sargs = ", ".join(f"{k}={v!r}" for k, v in patch_kwargs.items())
    code = [f"with torch_export_patches({sargs}):", *(f"    {line}" for line in code)]
    return "\n".join(imports), "\n".join(code)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def code_sample(
    model_id: str,
    task: Optional[str] = None,
    do_run: bool = False,
    exporter: Optional[str] = None,
    do_same: bool = False,
    verbose: int = 0,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    same_as_pretrained: bool = False,
    use_pretrained: bool = False,
    optimization: Optional[str] = None,
    quiet: bool = False,
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
    stop_if_static: int = 1,
    dump_folder: Optional[str] = None,
    drop_inputs: Optional[List[str]] = None,
    input_options: Optional[Dict[str, Any]] = None,
    model_options: Optional[Dict[str, Any]] = None,
    subfolder: Optional[str] = None,
    opset: Optional[int] = None,
    runtime: str = "onnxruntime",
    output_names: Optional[List[str]] = None,
) -> str:
    """
    This generates a code to export a model with the proper settings.

    :param model_id: model id to validate
    :param task: task used to generate the necessary inputs,
        can be left empty to use the default task for this model
        if it can be determined
    :param do_run: checks the model works with the defined inputs
        (not used yet by the code generation)
    :param exporter: exports the model using this exporter,
        ``custom`` or ``onnx-dynamo``; if None, the generated code
        only creates the model and its inputs
    :param do_same: checks the discrepancies of the exported model
        (not used yet by the code generation)
    :param verbose: verbosity level
    :param dtype: uses this dtype to check the model
    :param device: do the verification on this device
    :param same_as_pretrained: use a model equivalent to the trained,
        this is not always possible
    :param use_pretrained: use the trained model, not the untrained one
    :param optimization: optimization to apply to the exported model,
        depends on the exporter
    :param quiet: if quiet, catches exception if any issue
        (not used yet by the code generation)
    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
        if True before exporting
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
        a string can be used to specify only one of them
    :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param stop_if_static: stops if a dynamic dimension becomes static,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param dump_folder: the generated code creates this folder and saves
        ``model.onnx`` in it
    :param drop_inputs: drops this list of inputs (given their names)
    :param input_options: additional options to define the dummy inputs
        used to export
    :param model_options: additional options when creating the model such as
        ``num_hidden_layers`` or ``attn_implementation``
    :param subfolder: version or subfolders to use when retrieving a model id
    :param opset: onnx opset to use for the conversion
    :param runtime: onnx runtime to use to check about discrepancies,
        possible values ``onnxruntime``, ``torch``, ``orteval``,
        ``orteval10``, ``ref`` only if `do_run` is true
        (not used yet by the code generation)
    :param output_names: output names the onnx exporter should use
    :return: a code, formatted with ``black`` if it is installed

    .. runpython::
        :showcode:

        from onnx_diagnostic.torch_models.code_sample import code_sample

        print(
            code_sample(
                "arnir0/Tiny-LLM",
                exporter="onnx-dynamo",
                optimization="ir",
                patch=True,
            )
        )
    """
    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
        model_id,
        subfolder,
        same_as_pretrained=same_as_pretrained,
        use_pretrained=use_pretrained,
    )
    patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)

    data = get_untrained_model_with_inputs(
        model_id,
        verbose=verbose,
        task=task,
        use_pretrained=use_pretrained,
        same_as_pretrained=same_as_pretrained,
        inputs_kwargs=input_options or {},
        model_kwargs=model_options or {},
        subfolder=subfolder,
        add_second_input=False,
    )
    if drop_inputs:
        update = {}
        for k in data:
            if k.startswith("inputs"):
                update[k], ds = filter_inputs(
                    data[k],
                    drop_names=drop_inputs,
                    model=data["model"],
                    dynamic_shapes=data["dynamic_shapes"],
                )
                update["dynamic_shapes"] = ds
        data.update(update)

    # Converts every set of inputs to the requested dtype and device.
    update = {}
    for k in data:
        if k.startswith("inputs"):
            v = data[k]
            if dtype:
                update[k] = v = to_any(
                    v, getattr(torch, dtype) if isinstance(dtype, str) else dtype
                )
            if device:
                update[k] = v = to_any(v, device)
    if update:
        data.update(update)

    # Arguments of get_model_with_inputs in the generated code.
    args = [f"{model_id!r}"]
    if subfolder:
        args.append(f"subfolder={subfolder!r}")
    if dtype:
        args.append(f"dtype={dtype!r}")
    if device:
        args.append(f"device={device!r}")
    if same_as_pretrained:
        args.append(f"same_as_pretrained={same_as_pretrained!r}")
    if use_pretrained:
        args.append(f"use_pretrained={use_pretrained!r}")
    if input_options:
        args.append(f"input_options={input_options!r}")
    if model_options:
        args.append(f"model_options={model_options!r}")
    model_args = ", ".join(args)

    if exporter is None:
        # No export requested: only the model and its inputs are generated.
        # Strings (not lists) so that the pieces below can be joined.
        imports, exporter_code = "", ""
    else:
        imports, exporter_code = make_export_code(
            exporter=exporter,
            patch_kwargs=patch_kwargs,
            verbose=verbose,
            optimization=optimization,
            stop_if_static=stop_if_static,
            dump_folder=dump_folder,
            opset=opset,
            dynamic_shapes=data["dynamic_shapes"],
            output_names=output_names,
        )

    input_code = make_code_for_inputs(data["inputs"])
    cache_import = (
        "from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache"
        if "dynamic_cache" in input_code
        else ""
    )

    pieces = [
        CODE_SAMPLES["imports"],
        imports,
        cache_import,
        CODE_SAMPLES["get_model_with_inputs"],
        f"\nmodel = get_model_with_inputs({model_args})\n",
        f"inputs = {input_code}",
        exporter_code,
    ]
    # Empty pieces (no exporter, no cache) would only add blank lines.
    code = "\n".join(piece for piece in pieces if piece)
    try:
        import black
    except ImportError:
        # No black formatting.
        return code

    return black.format_str(code, mode=black.Mode())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .model_inputs import get_untrained_model_with_inputs
|