onnx-diagnostic 0.7.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +196 -5
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/helpers/cache_helper.py +19 -8
- onnx_diagnostic/helpers/log_helper.py +1335 -176
- onnx_diagnostic/tasks/image_text_to_text.py +69 -18
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +3 -3
- onnx_diagnostic/torch_models/hghub/hub_api.py +61 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
- onnx_diagnostic/torch_models/hghub/model_inputs.py +55 -14
- onnx_diagnostic/torch_models/validate.py +9 -4
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +17 -16
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
|
@@ -96,10 +96,10 @@ def get_inputs(
|
|
|
96
96
|
for i in range(num_hidden_layers)
|
|
97
97
|
]
|
|
98
98
|
),
|
|
99
|
-
|
|
99
|
+
pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
|
|
100
100
|
torch.int64
|
|
101
101
|
),
|
|
102
|
-
|
|
102
|
+
image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
|
|
103
103
|
torch.int64
|
|
104
104
|
),
|
|
105
105
|
)
|
|
@@ -132,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
|
|
|
132
132
|
If the configuration is None, the function selects typical dimensions.
|
|
133
133
|
"""
|
|
134
134
|
if config is not None:
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
135
|
+
if hasattr(config, "text_config"):
|
|
136
|
+
check_hasattr(
|
|
137
|
+
config.text_config,
|
|
138
|
+
"vocab_size",
|
|
139
|
+
"hidden_size",
|
|
140
|
+
"num_attention_heads",
|
|
141
|
+
("num_key_value_heads", "num_attention_heads"),
|
|
142
|
+
"intermediate_size",
|
|
143
|
+
"hidden_size",
|
|
144
|
+
)
|
|
145
|
+
check_hasattr(config, "vision_config")
|
|
146
|
+
text_config = True
|
|
147
|
+
else:
|
|
148
|
+
check_hasattr(
|
|
149
|
+
config,
|
|
150
|
+
"vocab_size",
|
|
151
|
+
"hidden_size",
|
|
152
|
+
"num_attention_heads",
|
|
153
|
+
("num_key_value_heads", "num_attention_heads"),
|
|
154
|
+
"intermediate_size",
|
|
155
|
+
"hidden_size",
|
|
156
|
+
"vision_config",
|
|
157
|
+
)
|
|
158
|
+
text_config = False
|
|
145
159
|
check_hasattr(config.vision_config, "image_size", "num_channels")
|
|
146
160
|
kwargs = dict(
|
|
147
161
|
batch_size=2,
|
|
@@ -150,17 +164,54 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
|
|
|
150
164
|
head_dim=(
|
|
151
165
|
16
|
|
152
166
|
if config is None
|
|
153
|
-
else getattr(
|
|
167
|
+
else getattr(
|
|
168
|
+
config,
|
|
169
|
+
"head_dim",
|
|
170
|
+
(config.text_config.hidden_size if text_config else config.hidden_size)
|
|
171
|
+
// (
|
|
172
|
+
config.text_config.num_attention_heads
|
|
173
|
+
if text_config
|
|
174
|
+
else config.num_attention_heads
|
|
175
|
+
),
|
|
176
|
+
)
|
|
177
|
+
),
|
|
178
|
+
dummy_max_token_id=(
|
|
179
|
+
31999
|
|
180
|
+
if config is None
|
|
181
|
+
else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
|
|
182
|
+
),
|
|
183
|
+
num_hidden_layers=(
|
|
184
|
+
4
|
|
185
|
+
if config is None
|
|
186
|
+
else (
|
|
187
|
+
config.text_config.num_hidden_layers
|
|
188
|
+
if text_config
|
|
189
|
+
else config.num_hidden_layers
|
|
190
|
+
)
|
|
154
191
|
),
|
|
155
|
-
dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
|
|
156
|
-
num_hidden_layers=4 if config is None else config.num_hidden_layers,
|
|
157
192
|
num_key_value_heads=(
|
|
158
193
|
8
|
|
159
194
|
if config is None
|
|
160
|
-
else
|
|
195
|
+
else (
|
|
196
|
+
_pick(config.text_config, "num_key_value_heads", "num_attention_heads")
|
|
197
|
+
if text_config
|
|
198
|
+
else _pick(config, "num_key_value_heads", "num_attention_heads")
|
|
199
|
+
)
|
|
200
|
+
),
|
|
201
|
+
intermediate_size=(
|
|
202
|
+
1024
|
|
203
|
+
if config is None
|
|
204
|
+
else (
|
|
205
|
+
config.text_config.intermediate_size
|
|
206
|
+
if text_config
|
|
207
|
+
else config.intermediate_size
|
|
208
|
+
)
|
|
209
|
+
),
|
|
210
|
+
hidden_size=(
|
|
211
|
+
512
|
|
212
|
+
if config is None
|
|
213
|
+
else (config.text_config.hidden_size if text_config else config.hidden_size)
|
|
161
214
|
),
|
|
162
|
-
intermediate_size=1024 if config is None else config.intermediate_size,
|
|
163
|
-
hidden_size=512 if config is None else config.hidden_size,
|
|
164
215
|
width=224 if config is None else config.vision_config.image_size,
|
|
165
216
|
height=224 if config is None else config.vision_config.image_size,
|
|
166
217
|
num_channels=3 if config is None else config.vision_config.num_channels,
|
|
@@ -3,9 +3,10 @@ import functools
|
|
|
3
3
|
import json
|
|
4
4
|
import os
|
|
5
5
|
import pprint
|
|
6
|
+
import sys
|
|
6
7
|
from typing import Any, Dict, List, Optional, Union
|
|
7
8
|
import transformers
|
|
8
|
-
from huggingface_hub import HfApi, model_info, hf_hub_download
|
|
9
|
+
from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
|
|
9
10
|
from ...helpers.config_helper import update_config
|
|
10
11
|
from . import hub_data_cached_configs
|
|
11
12
|
from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
|
|
@@ -138,12 +139,15 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
|
|
|
138
139
|
|
|
139
140
|
|
|
140
141
|
@functools.cache
|
|
141
|
-
def task_from_arch(
|
|
142
|
+
def task_from_arch(
|
|
143
|
+
arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
|
|
144
|
+
) -> str:
|
|
142
145
|
"""
|
|
143
146
|
This function relies on stored information. That information needs to be refresh.
|
|
144
147
|
|
|
145
148
|
:param arch: architecture name
|
|
146
149
|
:param default_value: default value in case the task cannot be determined
|
|
150
|
+
:param model_id: unused unless the architecture does not help.
|
|
147
151
|
:return: task
|
|
148
152
|
|
|
149
153
|
.. runpython::
|
|
@@ -156,9 +160,16 @@ def task_from_arch(arch: str, default_value: Optional[str] = None) -> str:
|
|
|
156
160
|
<onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
|
|
157
161
|
"""
|
|
158
162
|
data = load_architecture_task()
|
|
163
|
+
if arch not in data and model_id:
|
|
164
|
+
# Let's try with the model id.
|
|
165
|
+
return task_from_id(model_id)
|
|
159
166
|
if default_value is not None:
|
|
160
167
|
return data.get(arch, default_value)
|
|
161
|
-
assert arch in data,
|
|
168
|
+
assert arch in data, (
|
|
169
|
+
f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
|
|
170
|
+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
|
|
171
|
+
f"needs to be updated (model_id={(model_id or '?')!r})."
|
|
172
|
+
)
|
|
162
173
|
return data[arch]
|
|
163
174
|
|
|
164
175
|
|
|
@@ -176,6 +187,7 @@ def task_from_id(
|
|
|
176
187
|
if the task cannot be determined
|
|
177
188
|
:param pretrained: uses the config
|
|
178
189
|
:param fall_back_to_pretrained: falls back to pretrained config
|
|
190
|
+
:param exc: raises an exception if True
|
|
179
191
|
:return: task
|
|
180
192
|
"""
|
|
181
193
|
if not pretrained:
|
|
@@ -191,9 +203,14 @@ def task_from_id(
|
|
|
191
203
|
guess = _guess_task_from_config(config)
|
|
192
204
|
if guess is not None:
|
|
193
205
|
return guess
|
|
206
|
+
data = load_architecture_task()
|
|
207
|
+
if model_id in data:
|
|
208
|
+
return data[model_id]
|
|
194
209
|
assert config.architectures is not None and len(config.architectures) == 1, (
|
|
195
210
|
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
|
|
196
|
-
f"architectures={config.architectures} in config={config}"
|
|
211
|
+
f"architectures={config.architectures} in config={config}. "
|
|
212
|
+
f"The task can be added in "
|
|
213
|
+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
|
|
197
214
|
)
|
|
198
215
|
return task_from_arch(config.architectures[0], default_value=default_value)
|
|
199
216
|
|
|
@@ -311,3 +328,43 @@ def enumerate_model_list(
|
|
|
311
328
|
n -= 1
|
|
312
329
|
if n == 0:
|
|
313
330
|
break
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def download_code_modelid(
|
|
334
|
+
model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
|
|
335
|
+
) -> List[str]:
|
|
336
|
+
"""
|
|
337
|
+
Downloads the code for a given model id.
|
|
338
|
+
|
|
339
|
+
:param model_id: model id
|
|
340
|
+
:param verbose: verbosity
|
|
341
|
+
:param add_path_to_sys_path: add folder where the files are downloaded to sys.path
|
|
342
|
+
:return: list of downloaded files
|
|
343
|
+
"""
|
|
344
|
+
if verbose:
|
|
345
|
+
print(f"[download_code_modelid] retrieve file list for {model_id!r}")
|
|
346
|
+
files = list_repo_files(model_id)
|
|
347
|
+
pyfiles = [name for name in files if os.path.splitext(name)[-1] == ".py"]
|
|
348
|
+
if verbose:
|
|
349
|
+
print(f"[download_code_modelid] python files {pyfiles}")
|
|
350
|
+
absfiles = []
|
|
351
|
+
paths = set()
|
|
352
|
+
for i, name in enumerate(pyfiles):
|
|
353
|
+
if verbose:
|
|
354
|
+
print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
|
|
355
|
+
r = hf_hub_download(repo_id=model_id, filename=name)
|
|
356
|
+
p = os.path.split(r)[0]
|
|
357
|
+
paths.add(p)
|
|
358
|
+
absfiles.append(r)
|
|
359
|
+
if add_path_to_sys_path:
|
|
360
|
+
for p in paths:
|
|
361
|
+
init = os.path.join(p, "__init__.py")
|
|
362
|
+
if not os.path.exists(init):
|
|
363
|
+
with open(init, "w"):
|
|
364
|
+
pass
|
|
365
|
+
if p in sys.path:
|
|
366
|
+
continue
|
|
367
|
+
if verbose:
|
|
368
|
+
print(f"[download_code_modelid] add {p!r} to 'sys.path'")
|
|
369
|
+
sys.path.insert(0, p)
|
|
370
|
+
return absfiles
|
|
@@ -3,7 +3,7 @@ import functools
|
|
|
3
3
|
import textwrap
|
|
4
4
|
from typing import Dict, List
|
|
5
5
|
|
|
6
|
-
__date__ = "2025-
|
|
6
|
+
__date__ = "2025-06-21"
|
|
7
7
|
|
|
8
8
|
__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
|
|
9
9
|
|
|
@@ -52,6 +52,8 @@ __data_arch__ = textwrap.dedent(
|
|
|
52
52
|
GPTNeoModel,feature-extraction
|
|
53
53
|
GPTNeoXForCausalLM,text-generation
|
|
54
54
|
GemmaForCausalLM,text-generation
|
|
55
|
+
Gemma2ForCausalLM,text-generation
|
|
56
|
+
Gemma3ForConditionalGeneration,image-text-to-text
|
|
55
57
|
GraniteForCausalLM,text-generation
|
|
56
58
|
GroupViTModel,feature-extraction
|
|
57
59
|
HieraForImageClassification,image-classification
|
|
@@ -97,6 +99,7 @@ __data_arch__ = textwrap.dedent(
|
|
|
97
99
|
PegasusModel,feature-extraction
|
|
98
100
|
Phi3ForCausalLM,text-generation
|
|
99
101
|
PhiForCausalLM,text-generation
|
|
102
|
+
PhiMoEForCausalLM,text-generation
|
|
100
103
|
Pix2StructForConditionalGeneration,image-to-text
|
|
101
104
|
PLBartForConditionalGeneration,text2text-generation
|
|
102
105
|
PoolFormerModel,image-feature-extraction
|
|
@@ -144,7 +147,8 @@ __data_arch__ = textwrap.dedent(
|
|
|
144
147
|
XLMRobertaModel,sentence-similarity
|
|
145
148
|
Wav2Vec2ForCTC,automatic-speech-recognition
|
|
146
149
|
YolosForObjectDetection,object-detection
|
|
147
|
-
YolosModel,image-feature-extraction
|
|
150
|
+
YolosModel,image-feature-extraction
|
|
151
|
+
emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
|
|
148
152
|
)
|
|
149
153
|
|
|
150
154
|
__data_tasks__ = [
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import os
|
|
3
|
+
import pprint
|
|
3
4
|
from typing import Any, Dict, Optional, Tuple
|
|
4
5
|
import torch
|
|
5
6
|
import transformers
|
|
6
7
|
from ...helpers.config_helper import update_config
|
|
7
8
|
from ...tasks import reduce_model_config, random_input_kwargs
|
|
8
|
-
from .hub_api import task_from_arch, task_from_id, get_pretrained_config
|
|
9
|
+
from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def _code_needing_rewriting(model: Any) -> Any:
|
|
@@ -22,6 +23,7 @@ def get_untrained_model_with_inputs(
|
|
|
22
23
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
|
23
24
|
verbose: int = 0,
|
|
24
25
|
dynamic_rope: Optional[bool] = None,
|
|
26
|
+
use_pretrained: bool = False,
|
|
25
27
|
same_as_pretrained: bool = False,
|
|
26
28
|
use_preinstalled: bool = True,
|
|
27
29
|
add_second_input: bool = False,
|
|
@@ -43,6 +45,7 @@ def get_untrained_model_with_inputs(
|
|
|
43
45
|
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
|
|
44
46
|
:param same_as_pretrained: if True, do not change the default values
|
|
45
47
|
to get a smaller model
|
|
48
|
+
:param use_pretrained: download the pretrained weights as well
|
|
46
49
|
:param use_preinstalled: use preinstalled configurations
|
|
47
50
|
:param add_second_input: provides a second inputs to check a model
|
|
48
51
|
supports different shapes
|
|
@@ -68,6 +71,10 @@ def get_untrained_model_with_inputs(
|
|
|
68
71
|
print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
|
|
69
72
|
print("-- configuration:", pprint.pformat(data['configuration']))
|
|
70
73
|
"""
|
|
74
|
+
assert not use_preinstalled or not use_only_preinstalled, (
|
|
75
|
+
f"model_id={model_id!r}, pretinstalled model is only available "
|
|
76
|
+
f"if use_only_preinstalled is False."
|
|
77
|
+
)
|
|
71
78
|
if verbose:
|
|
72
79
|
print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
|
|
73
80
|
if use_preinstalled:
|
|
@@ -99,7 +106,7 @@ def get_untrained_model_with_inputs(
|
|
|
99
106
|
print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
|
|
100
107
|
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
|
|
101
108
|
if task is None:
|
|
102
|
-
task = task_from_arch(archs[0])
|
|
109
|
+
task = task_from_arch(archs[0], model_id=model_id)
|
|
103
110
|
if verbose:
|
|
104
111
|
print(f"[get_untrained_model_with_inputs] task={task!r}")
|
|
105
112
|
|
|
@@ -114,7 +121,6 @@ def get_untrained_model_with_inputs(
|
|
|
114
121
|
)
|
|
115
122
|
|
|
116
123
|
# updating the configuration
|
|
117
|
-
|
|
118
124
|
mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
|
|
119
125
|
if model_kwargs:
|
|
120
126
|
for k, v in model_kwargs.items():
|
|
@@ -139,27 +145,62 @@ def get_untrained_model_with_inputs(
|
|
|
139
145
|
f"{config._attn_implementation!r}" # type: ignore[union-attr]
|
|
140
146
|
)
|
|
141
147
|
|
|
148
|
+
if use_pretrained:
|
|
149
|
+
model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
|
|
150
|
+
else:
|
|
151
|
+
if archs is not None:
|
|
152
|
+
try:
|
|
153
|
+
model = getattr(transformers, archs[0])(config)
|
|
154
|
+
except AttributeError as e:
|
|
155
|
+
# The code of the models is not in transformers but in the
|
|
156
|
+
# repository of the model. We need to download it.
|
|
157
|
+
pyfiles = download_code_modelid(model_id, verbose=verbose)
|
|
158
|
+
if pyfiles:
|
|
159
|
+
if "." in archs[0]:
|
|
160
|
+
cls_name = archs[0]
|
|
161
|
+
else:
|
|
162
|
+
modeling = [_ for _ in pyfiles if "/modeling_" in _]
|
|
163
|
+
assert len(modeling) == 1, (
|
|
164
|
+
f"Unable to guess the main file implemented class {archs[0]!r} "
|
|
165
|
+
f"from {pyfiles}, found={modeling}."
|
|
166
|
+
)
|
|
167
|
+
last_name = os.path.splitext(os.path.split(modeling[0])[-1])[0]
|
|
168
|
+
cls_name = f"{last_name}.{archs[0]}"
|
|
169
|
+
if verbose:
|
|
170
|
+
print(
|
|
171
|
+
f"[get_untrained_model_with_inputs] custom code for {cls_name!r}"
|
|
172
|
+
)
|
|
173
|
+
print(
|
|
174
|
+
f"[get_untrained_model_with_inputs] from folder "
|
|
175
|
+
f"{os.path.split(pyfiles[0])[0]!r}"
|
|
176
|
+
)
|
|
177
|
+
cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
|
|
178
|
+
cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
|
|
179
|
+
)
|
|
180
|
+
model = cls(config)
|
|
181
|
+
else:
|
|
182
|
+
raise AttributeError(
|
|
183
|
+
f"Unable to find class 'tranformers.{archs[0]}'. "
|
|
184
|
+
f"The code needs to be downloaded, config="
|
|
185
|
+
f"\n{pprint.pformat(config)}."
|
|
186
|
+
) from e
|
|
187
|
+
else:
|
|
188
|
+
assert same_as_pretrained and use_pretrained, (
|
|
189
|
+
f"Model {model_id!r} cannot be built, the model cannot be built. "
|
|
190
|
+
f"It must be downloaded. Use same_as_pretrained=True "
|
|
191
|
+
f"and use_pretrained=True."
|
|
192
|
+
)
|
|
193
|
+
|
|
142
194
|
# input kwargs
|
|
143
195
|
kwargs, fct = random_input_kwargs(config, task)
|
|
144
196
|
if verbose:
|
|
145
197
|
print(f"[get_untrained_model_with_inputs] use fct={fct}")
|
|
146
198
|
if os.environ.get("PRINT_CONFIG") in (1, "1"):
|
|
147
|
-
import pprint
|
|
148
|
-
|
|
149
199
|
print(f"-- input kwargs for task {task!r}")
|
|
150
200
|
pprint.pprint(kwargs)
|
|
151
201
|
if inputs_kwargs:
|
|
152
202
|
kwargs.update(inputs_kwargs)
|
|
153
203
|
|
|
154
|
-
if archs is not None:
|
|
155
|
-
model = getattr(transformers, archs[0])(config)
|
|
156
|
-
else:
|
|
157
|
-
assert same_as_pretrained, (
|
|
158
|
-
f"Model {model_id!r} cannot be built, the model cannot be built. "
|
|
159
|
-
f"It must be downloaded. Use same_as_pretrained=True."
|
|
160
|
-
)
|
|
161
|
-
model = None
|
|
162
|
-
|
|
163
204
|
# This line is important. Some models may produce different
|
|
164
205
|
# outputs even with the same inputs in training mode.
|
|
165
206
|
model.eval()
|
|
@@ -259,7 +259,8 @@ def validate_model(
|
|
|
259
259
|
verbose: int = 0,
|
|
260
260
|
dtype: Optional[Union[str, torch.dtype]] = None,
|
|
261
261
|
device: Optional[Union[str, torch.device]] = None,
|
|
262
|
-
|
|
262
|
+
same_as_pretrained: bool = False,
|
|
263
|
+
use_pretrained: bool = False,
|
|
263
264
|
optimization: Optional[str] = None,
|
|
264
265
|
quiet: bool = False,
|
|
265
266
|
patch: bool = False,
|
|
@@ -294,7 +295,9 @@ def validate_model(
|
|
|
294
295
|
:param verbose: verbosity level
|
|
295
296
|
:param dtype: uses this dtype to check the model
|
|
296
297
|
:param device: do the verification on this device
|
|
297
|
-
:param
|
|
298
|
+
:param same_as_pretrained: use a model equivalent to the trained,
|
|
299
|
+
this is not always possible
|
|
300
|
+
:param use_pretrained: use the trained model, not the untrained one
|
|
298
301
|
:param optimization: optimization to apply to the exported model,
|
|
299
302
|
depend on the the exporter
|
|
300
303
|
:param quiet: if quiet, catches exception if any issue
|
|
@@ -353,7 +356,8 @@ def validate_model(
|
|
|
353
356
|
version_do_run=str(do_run),
|
|
354
357
|
version_dtype=str(dtype or ""),
|
|
355
358
|
version_device=str(device or ""),
|
|
356
|
-
|
|
359
|
+
version_same_as_pretrained=str(same_as_pretrained),
|
|
360
|
+
version_use_pretrained=str(use_pretrained),
|
|
357
361
|
version_optimization=optimization or "",
|
|
358
362
|
version_quiet=str(quiet),
|
|
359
363
|
version_patch=str(patch),
|
|
@@ -408,11 +412,12 @@ def validate_model(
|
|
|
408
412
|
summary,
|
|
409
413
|
None,
|
|
410
414
|
(
|
|
411
|
-
lambda mid=model_id, v=verbose, task=task, tr=
|
|
415
|
+
lambda mid=model_id, v=verbose, task=task, uptr=use_pretrained, tr=same_as_pretrained, iop=iop, sub=subfolder, i2=inputs2: ( # noqa: E501
|
|
412
416
|
get_untrained_model_with_inputs(
|
|
413
417
|
mid,
|
|
414
418
|
verbose=v,
|
|
415
419
|
task=task,
|
|
420
|
+
use_pretrained=uptr,
|
|
416
421
|
same_as_pretrained=tr,
|
|
417
422
|
inputs_kwargs=iop,
|
|
418
423
|
model_kwargs=mop,
|
|
@@ -1,21 +1,22 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=YwjIZRhfTzRgTOBvmUSNNYX0SBBdmLsWfkMVwHkJloQ,173
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
|
-
onnx_diagnostic/_command_lines_parser.py,sha256=
|
|
3
|
+
onnx_diagnostic/_command_lines_parser.py,sha256=WpCri2dqc1a1KYthQcb4-eN0htfiWeLrAkncNK3cZaY,27466
|
|
4
4
|
onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
|
|
5
5
|
onnx_diagnostic/doc.py,sha256=t3RELgfooYnVMAi0JSpggWkQEgUsREz8NmRvn0TnLI8,2829
|
|
6
6
|
onnx_diagnostic/ext_test_case.py,sha256=IX-DNabvsPw8UkUeXC1amw3nnzdmJ3DeERn4E1Y_omo,42416
|
|
7
7
|
onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
|
|
8
|
-
onnx_diagnostic/export/dynamic_shapes.py,sha256=
|
|
8
|
+
onnx_diagnostic/export/dynamic_shapes.py,sha256=HYf2OEi7PmRSj8uxMD-wbdVxxejkWXTPBAkxoFeM27A,40811
|
|
9
|
+
onnx_diagnostic/export/shape_helper.py,sha256=C9cEq_x8I40RKuD89qWIholN1XZoWhaKPfbZQhiPD3g,4725
|
|
9
10
|
onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
|
|
10
11
|
onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
|
|
11
12
|
onnx_diagnostic/helpers/args_helper.py,sha256=SRWnqC7EENg09RZlA50B_PcdiIhdbgA4C3ACfzl5nMs,4419
|
|
12
13
|
onnx_diagnostic/helpers/bench_run.py,sha256=CGA6VMJZMH2gDhVueT9ypNm4PMcjGrrGFYp08nhWj9k,16539
|
|
13
|
-
onnx_diagnostic/helpers/cache_helper.py,sha256=
|
|
14
|
+
onnx_diagnostic/helpers/cache_helper.py,sha256=E_D0z5ks4zUJb9H6K19MKhUZ-nZTC_dgeDO5zXNFE9g,10824
|
|
14
15
|
onnx_diagnostic/helpers/config_helper.py,sha256=CdMeUhmDe0LfKcdPv9-Lzt73RRs29NmUHg9uVrdFwTQ,3479
|
|
15
16
|
onnx_diagnostic/helpers/doc_helper.py,sha256=pl5MZd3_FaE8BqQnqoBuSBxoNCFcd2OJd3eITUSku5c,5897
|
|
16
17
|
onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
|
|
17
18
|
onnx_diagnostic/helpers/helper.py,sha256=_6K0IvfK7ymBE8uWFAOA1ksU_fMvl2BRtlxj5SA9R2I,58203
|
|
18
|
-
onnx_diagnostic/helpers/log_helper.py,sha256=
|
|
19
|
+
onnx_diagnostic/helpers/log_helper.py,sha256=qZdvHHQqkYdZOf8UsIrByswMYSF_axca27JXRyQk52Y,69163
|
|
19
20
|
onnx_diagnostic/helpers/memory_peak.py,sha256=OT6mz0muBbBZY0pjgW2_eCk_lOtFRo-5w4jFo2Z6Kok,6380
|
|
20
21
|
onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=p0Xh2Br38xAqUjB2214GiNOIbCgiVZKeyVEnjdyqyFI,21091
|
|
21
22
|
onnx_diagnostic/helpers/model_builder_helper.py,sha256=RvDyPFqRboEU3HsQV_xi9oy-o3_4KuGFVzs5MhksduY,12552
|
|
@@ -75,7 +76,7 @@ onnx_diagnostic/tasks/automatic_speech_recognition.py,sha256=7OspFypNHLSL6huvP9m
|
|
|
75
76
|
onnx_diagnostic/tasks/feature_extraction.py,sha256=CbxbGsv3JvEQ2J9tO2DOpMHcJj5ZlCwY81ZB3hPB4D4,2339
|
|
76
77
|
onnx_diagnostic/tasks/fill_mask.py,sha256=ZWz8swzEeRbkmbY9oZ4CM1LYCWWUxnS5CqrKmUVw-u0,2457
|
|
77
78
|
onnx_diagnostic/tasks/image_classification.py,sha256=UjUAFYnwXIdPMXJdHR5MDzpsfMeIvyuKR4RqJVpGV_Q,4449
|
|
78
|
-
onnx_diagnostic/tasks/image_text_to_text.py,sha256=
|
|
79
|
+
onnx_diagnostic/tasks/image_text_to_text.py,sha256=LmpMdH6oF_EN3WIACzSip4fPZOjZWFOoXg4k8qAio6Q,7639
|
|
79
80
|
onnx_diagnostic/tasks/mixture_of_expert.py,sha256=C0ugEc8OWmVyEZpsh8MJq_te1zgOHhpITtnSmGC16Ls,2801
|
|
80
81
|
onnx_diagnostic/tasks/object_detection.py,sha256=1lF5e2f2Coz1urSptEKgvUGCOSFBf0Anuq_QYOC00dA,4046
|
|
81
82
|
onnx_diagnostic/tasks/sentence_similarity.py,sha256=3MvNxjC1iEMtQL_jH1c8bmrVc5IG1lfUygrCZ0SORJk,2472
|
|
@@ -86,7 +87,7 @@ onnx_diagnostic/tasks/text_generation.py,sha256=PDh870BB-llzlu8h_aZX4Z-9QLzcGmDw
|
|
|
86
87
|
onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=GKaXm8g7cK23h3wJEUc6Q-6mpmLAzQ4YkJbd-eGP7Y4,4496
|
|
87
88
|
onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
|
|
88
89
|
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=op8jgnTa_1T_bGN172A6YFTtkQv_ALMNu1oukrsFt9U,20634
|
|
89
|
-
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=
|
|
90
|
+
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=Nsf7HUJqu3ZRd0o9vUVCF6ifmS5UaQM6hB_Gmn19dNI,17095
|
|
90
91
|
onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
|
|
91
92
|
onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=9b4pmyT00BwLqi7WG-gliep1RUy3gXEgW6BDnlSSA-M,7689
|
|
92
93
|
onnx_diagnostic/torch_export_patches/patch_module.py,sha256=R2d9IHM-RwsBKDsxuBIJnEqMoxbS9gd4YWFGG2wwV5A,39881
|
|
@@ -98,20 +99,20 @@ onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=KaZ8TjDa9ATgT
|
|
|
98
99
|
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=GwcPUaSm-Zys2pWHac8Wcvpmy2h4oiFQDmx_D3GZNBA,41007
|
|
99
100
|
onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
100
101
|
onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
|
|
101
|
-
onnx_diagnostic/torch_models/validate.py,sha256=
|
|
102
|
+
onnx_diagnostic/torch_models/validate.py,sha256=NXFmJKmoO4reaeiu2ibuVgMRLS-l0WSdLhjn40_YsbU,62177
|
|
102
103
|
onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
|
|
103
|
-
onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=
|
|
104
|
-
onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=
|
|
104
|
+
onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=q4jUgSJ8AD28mpX7yDAUp0z7EQgb8euuD-P9Hayehds,12672
|
|
105
|
+
onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=lM9IDb5-3X8NSHcSPJFLS3tAvu_FqvcetyoHn-P2FIM,8272
|
|
105
106
|
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=dE8tHksGOsTk77jpa7mldLYzaQ5joKxDxDB0ZnwQBV4,267246
|
|
106
|
-
onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=
|
|
107
|
+
onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=XCllU8_fJBjkCGSV8cdqlpF1QH6AN_OAErK0aAXNQts,10261
|
|
107
108
|
onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
108
109
|
onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiFDBdPP0rFGPteiONDxvztw,3708
|
|
109
110
|
onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tmtVl83SmVOL4-Um7Qy-f0E48QI,2507
|
|
110
111
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
111
112
|
onnx_diagnostic/torch_onnx/runtime_info.py,sha256=1g9F_Jf9AAgYQU4stbsrFXwQl-30mWlQrFbQ7val8Ps,9268
|
|
112
113
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=1EL25DeYFzlBSiFG_XjePBLvsiItRXbdDrr5-QZW2mA,16878
|
|
113
|
-
onnx_diagnostic-0.7.
|
|
114
|
-
onnx_diagnostic-0.7.
|
|
115
|
-
onnx_diagnostic-0.7.
|
|
116
|
-
onnx_diagnostic-0.7.
|
|
117
|
-
onnx_diagnostic-0.7.
|
|
114
|
+
onnx_diagnostic-0.7.1.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
115
|
+
onnx_diagnostic-0.7.1.dist-info/METADATA,sha256=7YUf_2f3gFaGkWUbOrZNqREtnQSpRXx5q65HJEDcDI8,6631
|
|
116
|
+
onnx_diagnostic-0.7.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
117
|
+
onnx_diagnostic-0.7.1.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
118
|
+
onnx_diagnostic-0.7.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|