onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import functools
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import pprint
|
|
6
|
+
import sys
|
|
7
|
+
from typing import Any, Dict, List, Optional, Union
|
|
8
|
+
import transformers
|
|
9
|
+
from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
|
|
10
|
+
from ...helpers.config_helper import update_config
|
|
11
|
+
from . import hub_data_cached_configs
|
|
12
|
+
from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@functools.cache
def get_architecture_default_values(architecture: str):
    """
    Returns default values completing a configuration which misses the
    information needed to build dummy inputs for *architecture*.

    :param architecture: architecture (class) name
    :return: dictionary of default values
    :raises AssertionError: if no defaults are known for this architecture
    """
    known = __data_arch_values__
    assert architecture in known, (
        f"No known default values for {architecture!r}, "
        f"expecting one architecture in {', '.join(sorted(known))}"
    )
    return known[architecture]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@functools.cache
def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig]:
    """
    Collects the cached configuration builders defined in
    ``hub_data_cached_configs`` (functions prefixed with ``_ccached_``),
    keyed by their docstring (which holds the model id).
    """
    return {
        fct.__doc__: fct
        for name, fct in hub_data_cached_configs.__dict__.items()
        if name.startswith("_ccached_")
    }
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_cached_configuration(
    name: str, exc: bool = False, **kwargs
) -> Optional[transformers.PretrainedConfig]:
    """
    Returns a cached configuration to avoid too many accesses to the internet.
    It returns None if the configuration is not cached. The list of cached models follows.
    If *exc* is True or if environment variable ``NOHTTP`` is defined,
    the function raises an exception if *name* is not found.

    .. runpython::

        import pprint
        from onnx_diagnostic.torch_models.hghub.hub_api import _retrieve_cached_configurations

        configs = _retrieve_cached_configurations()
        pprint.pprint(sorted(configs))
    """
    cached = _retrieve_cached_configurations()
    assert cached, "no cached configuration, which is weird"
    if name in cached:
        conf = cached[name]()
        if kwargs:
            # Deep copy before applying overrides so the shared cached
            # instance is never mutated.
            conf = copy.deepcopy(conf)
            update_config(conf, kwargs)
        return conf
    # Not cached: fail loudly when the caller asked for it (exc=True)
    # or when network access is explicitly forbidden (NOHTTP).
    assert not exc and not os.environ.get("NOHTTP", ""), (
        f"Unable to find {name!r} (exc={exc}, "
        f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
        f"in {pprint.pformat(sorted(cached))}"
    )
    return None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_pretrained_config(
    model_id: str,
    trust_remote_code: bool = True,
    use_preinstalled: bool = True,
    subfolder: Optional[str] = None,
    use_only_preinstalled: bool = False,
    **kwargs,
) -> Any:
    """
    Returns the config for a model_id.

    :param model_id: model id
    :param trust_remote_code: trust_remote_code,
        see :meth:`transformers.AutoConfig.from_pretrained`
    :param use_preinstalled: if use_preinstalled, uses this version to avoid
        accessing the network, if available, it is returned by
        :func:`get_cached_configuration`, the cached list is mostly for
        unit tests
    :param subfolder: subfolder for the given model id
    :param use_only_preinstalled: if True, raises an exception if not preinstalled
    :param kwargs: additional kwargs
    :return: a configuration (or a plain dict for diffusers-style configs)
    """
    if use_preinstalled:
        # Cached lookup first; raises (via exc=...) when only preinstalled
        # configurations are allowed.
        conf = get_cached_configuration(
            model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
        )
        if conf is not None:
            return conf
    assert not use_only_preinstalled, (
        f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
        f"use_preinstalled={use_preinstalled!r}"
    )
    if subfolder:
        try:
            return transformers.AutoConfig.from_pretrained(
                model_id, trust_remote_code=trust_remote_code, subfolder=subfolder, **kwargs
            )
        except ValueError:
            # Then we try to download it.
            config = hf_hub_download(
                model_id, filename="config.json", subfolder=subfolder, **kwargs
            )
            try:
                return transformers.AutoConfig.from_pretrained(
                    config, trust_remote_code=trust_remote_code, **kwargs
                )
            except ValueError:
                # Diffusers uses a dictionary.
                # JSON config files are UTF-8; be explicit about the encoding.
                with open(config, "r", encoding="utf-8") as f:
                    return json.load(f)
    return transformers.AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, **kwargs
    )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def get_model_info(model_id) -> Any:
    """Returns the model info for a model_id (see :epkg:`huggingface_hub`)."""
    info = model_info(model_id)
    return info
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _guess_task_from_config(config: Any) -> Optional[str]:
    """Tries to infer a task from the configuration, None if it cannot."""
    # Detection configurations expose these two loss coefficients.
    if hasattr(config, "bbox_loss_coefficient") and hasattr(config, "giou_loss_coefficient"):
        return "object-detection"
    arch = getattr(config, "architecture", None)
    if arch:
        return task_from_arch(arch)
    return None
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@functools.cache
def task_from_arch(
    arch: str,
    default_value: Optional[str] = None,
    model_id: Optional[str] = None,
    subfolder: Optional[str] = None,
) -> str:
    """
    This function relies on stored information. That information needs to be refresh.

    :param arch: architecture name
    :param default_value: default value in case the task cannot be determined
    :param model_id: unused unless the architecture does not help.
    :param subfolder: subfolder
    :return: task

    .. runpython::

        from onnx_diagnostic.torch_models.hghub.hub_data import __date__
        print("last refresh", __date__)

    List of supported architectures, see
    :func:`load_architecture_task
    <onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
    """
    mapping = load_architecture_task()
    if model_id and arch not in mapping:
        # Unknown architecture: fall back to a lookup based on the model id.
        return task_from_id(model_id, subfolder=subfolder)
    if default_value is not None:
        return mapping.get(arch, default_value)
    assert arch in mapping, (
        f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
        f"needs to be updated (model_id={(model_id or '?')!r})."
    )
    return mapping[arch]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _trygetattr(config, attname):
    """Returns ``getattr(config, attname)`` or None when the attribute is missing."""
    # getattr's three-argument form swallows AttributeError exactly like
    # the try/except it replaces.
    return getattr(config, attname, None)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def architecture_from_config(config) -> Optional[str]:
    """Guesses the architecture (class) of the model described by this config."""

    def _from_mapping(d):
        # Keys checked in priority order; an empty ``architectures`` list
        # counts as absent.
        if "_class_name" in d:
            return d["_class_name"]
        if "architecture" in d:
            return d["architecture"]
        if d.get("architectures", []):
            return d["architectures"][0]
        return None

    if isinstance(config, dict):
        return _from_mapping(config)
    if hasattr(config, "_class_name"):
        return config._class_name
    if hasattr(config, "architecture"):
        return config.architecture
    if hasattr(config, "architectures") and config.architectures:
        return config.architectures[0]
    if hasattr(config, "__dict__"):
        return _from_mapping(config.__dict__)
    return None
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def find_package_source(config) -> Optional[str]:
    """Guesses the package the class models from ("diffusers" or "transformers")."""
    # Diffusers configurations always carry a ``_diffusers_version`` marker,
    # either as a dict key or as an attribute.
    if isinstance(config, dict):
        return "diffusers" if "_diffusers_version" in config else "transformers"
    if hasattr(config, "_diffusers_version"):
        return "diffusers"
    if "_diffusers_version" in getattr(config, "__dict__", {}):
        return "diffusers"
    return "transformers"
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def task_from_id(
    model_id: str,
    default_value: Optional[str] = None,
    pretrained: bool = False,
    fall_back_to_pretrained: bool = True,
    subfolder: Optional[str] = None,
) -> str:
    """
    Returns the task attached to a model id.

    :param model_id: model id
    :param default_value: if specified, the function returns this value
        if the task cannot be determined
    :param pretrained: uses the config
    :param fall_back_to_pretrained: falls back to pretrained config
    :param subfolder: subfolder
    :return: task
    """
    if not pretrained:
        try:
            # NOTE(review): the result of get_task is discarded — possibly a
            # missing ``return``, or this call only validates the model id;
            # confirm the intent before changing it.
            transformers.pipelines.get_task(model_id)
        except RuntimeError:
            if not fall_back_to_pretrained:
                raise
    # Fallback chain: pipeline_tag on the config, then heuristics on the
    # config content, then the stored model-id/architecture tables.
    config = get_pretrained_config(model_id, subfolder=subfolder)
    tag = _trygetattr(config, "pipeline_tag")
    if tag is not None:
        return tag

    guess = _guess_task_from_config(config)
    if guess is not None:
        return guess
    data = load_architecture_task()
    if subfolder:
        # Model ids with a subfolder are stored as ``<model_id>//<subfolder>``.
        full_id = f"{model_id}//{subfolder}"
        if full_id in data:
            return data[full_id]
    if model_id in data:
        return data[model_id]
    arch = architecture_from_config(config)
    if arch is None:
        # Special case: google BERT checkpoints do not expose an architecture.
        if model_id.startswith("google/bert_"):
            return "fill-mask"
    assert arch is not None, (
        f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
        f"config={config}. The task can be added in "
        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
    )
    return task_from_arch(arch, default_value=default_value)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def task_from_tags(tags: Union[str, List[str]]) -> str:
    """
    Guesses the task from the list of tags.
    If given by a string, ``|`` should be the separator.
    """
    if isinstance(tags, str):
        tags = tags.split("|")
    lookup = set(tags)
    # Scan the known tasks in their declared order and keep the first match.
    found = next((task for task in __data_tasks__ if task in lookup), None)
    if found is not None:
        return found
    raise ValueError(f"Unable to guess the task from tags={tags!r}")
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def enumerate_model_list(
    n: int = 50,
    pipeline_tag: Optional[str] = None,
    search: Optional[str] = None,
    dump: Optional[str] = None,
    filter: Optional[Union[str, List[str]]] = None,
    verbose: int = 0,
):
    """
    Enumerates models coming from :epkg:`huggingface_hub`.

    :param n: number of models to retrieve (-1 for all)
    :param pipeline_tag: see :meth:`huggingface_hub.HfApi.list_models`
    :param search: see :meth:`huggingface_hub.HfApi.list_models`
    :param filter: see :meth:`huggingface_hub.HfApi.list_models`
    :param dump: dumps the result in this csv file
    :param verbose: show progress
    """
    api = HfApi()
    models = api.list_models(
        pipeline_tag=pipeline_tag,
        search=search,
        full=True,
        filter=filter,
        limit=n if n > 0 else None,
    )
    seen = 0
    found = 0

    if dump:
        with open(dump, "w") as f:
            f.write(
                ",".join(
                    [
                        "id",
                        "model_name",
                        "author",
                        "created_at",
                        "last_modified",
                        "downloads",
                        "downloads_all_time",
                        "likes",
                        "trending_score",
                        "private",
                        "gated",
                        "tags",
                    ]
                )
            )
            f.write("\n")

    for m in models:
        seen += 1  # noqa: SIM113
        if verbose and seen % 1000 == 0:
            print(f"[enumerate_model_list] {seen} models, found {found}")
        if verbose > 1:
            print(
                f"[enumerate_model_list] id={m.id!r}, "
                f"library={m.library_name!r}, task={m.task!r}"
            )
        if dump:
            # BUG FIX: the per-row write was unguarded, so calling the
            # function without ``dump`` raised ``TypeError`` on
            # ``open(None, "a")``; the header write above was already guarded.
            with open(dump, "a") as f:
                f.write(
                    ",".join(
                        map(
                            str,
                            [
                                m.id,
                                getattr(m, "model_name", "") or "",
                                m.author or "",
                                str(m.created_at or "").split(" ")[0],
                                str(m.last_modified or "").split(" ")[0],
                                m.downloads or "",
                                m.downloads_all_time or "",
                                m.likes or "",
                                m.trending_score or "",
                                m.private or "",
                                m.gated or "",
                                (
                                    ("|".join(m.tags)).replace(",", "_").replace(" ", "_")
                                    if m.tags
                                    else ""
                                ),
                            ],
                        )
                    )
                )
                f.write("\n")
        yield m
        found += 1  # noqa: SIM113
        if n >= 0:
            n -= 1
            if n == 0:
                break
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def download_code_modelid(
    model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
) -> List[str]:
    """
    Downloads the code for a given model id.

    :param model_id: model id
    :param verbose: verbosity
    :param add_path_to_sys_path: add folder where the files are downloaded to sys.path
    :return: list of downloaded files
    """
    if verbose:
        print(f"[download_code_modelid] retrieve file list for {model_id!r}")
    repo_files = list_repo_files(model_id)
    pyfiles = [f for f in repo_files if os.path.splitext(f)[-1] == ".py"]
    if verbose:
        print(f"[download_code_modelid] python files {pyfiles}")
    downloaded = []
    folders = set()
    for index, name in enumerate(pyfiles):
        if verbose:
            print(f"[download_code_modelid] download file {index+1}/{len(pyfiles)}: {name!r}")
        local_path = hf_hub_download(repo_id=model_id, filename=name)
        folders.add(os.path.split(local_path)[0])
        downloaded.append(local_path)
    if add_path_to_sys_path:
        for folder in folders:
            init = os.path.join(folder, "__init__.py")
            if not os.path.exists(init):
                # Create an empty __init__.py so the folder imports as a package.
                with open(init, "w"):
                    pass
            if folder not in sys.path:
                if verbose:
                    print(f"[download_code_modelid] add {folder!r} to 'sys.path'")
                sys.path.insert(0, folder)
    return downloaded
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import functools
|
|
3
|
+
import textwrap
|
|
4
|
+
from typing import Dict, List
|
|
5
|
+
|
|
6
|
+
# Date of the last refresh of the architecture/task tables below.
__date__ = "2025-06-21"

# Default values completing configurations which miss the information
# needed to build dummy inputs (consumed by hub_api.get_architecture_default_values).
__data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}

# CSV table mapping an architecture (class name) -- or, for the last entries,
# a specific model id, with ``//`` separating a subfolder -- to its task.
# Parsed by :func:`load_architecture_task` below.
__data_arch__ = textwrap.dedent(
    """
architecture,task
ASTModel,feature-extraction
AutoencoderKL,image-to-image
AlbertModel,feature-extraction
BeitForImageClassification,image-classification
BartForConditionalGeneration,summarization
BartModel,feature-extraction
BertForMaskedLM,fill-mask
BertForSequenceClassification,text-classification
BertModel,sentence-similarity
BigBirdModel,feature-extraction
BlenderbotModel,feature-extraction
BloomModel,feature-extraction
CLIPModel,zero-shot-image-classification
CLIPTextModel,feature-extraction
CLIPVisionModel,feature-extraction
CamembertModel,feature-extraction
CodeGenModel,feature-extraction
ConvBertModel,feature-extraction
ConvNextForImageClassification,image-classification
ConvNextV2Model,image-feature-extraction
CosmosTransformer3DModel,image-to-video
CvtModel,feature-extraction
DPTModel,image-feature-extraction
Data2VecAudioModel,feature-extraction
Data2VecTextModel,feature-extraction
Data2VecVisionModel,image-feature-extraction
DebertaModel,feature-extraction
DebertaV2Model,feature-extraction
DecisionTransformerModel,reinforcement-learning
DeepseekV3ForCausalLM,text-generation
DeiTModel,image-feature-extraction
DetrModel,image-feature-extraction
Dinov2Model,image-feature-extraction
DistilBertForSequenceClassification,text-classification
DistilBertModel,feature-extraction
DonutSwinModel,feature-extraction
ElectraModel,feature-extraction
EsmModel,feature-extraction
FalconMambaForCausalLM,text-generation
GLPNModel,image-feature-extraction
GPT2LMHeadModel,text-generation
GPTBigCodeModel,feature-extraction
GPTJModel,feature-extraction
GPTNeoModel,feature-extraction
GPTNeoXForCausalLM,text-generation
GptOssForCausalLM,text-generation
GemmaForCausalLM,text-generation
Gemma2ForCausalLM,text-generation
Gemma3ForConditionalGeneration,image-text-to-text
Gemma3ForCausalLM,text-generation
Glm4vMoeForConditionalGeneration,image-text-to-text
GraniteForCausalLM,text-generation
GroupViTModel,feature-extraction
HieraForImageClassification,image-classification
HubertModel,feature-extraction
IBertModel,feature-extraction
IdeficsForVisionText2Text,image-text-to-text
ImageGPTModel,image-feature-extraction
LayoutLMModel,feature-extraction
LayoutLMv3Model,feature-extraction
LevitModel,image-feature-extraction
LiltModel,feature-extraction
LlamaForCausalLM,text-generation
LongT5Model,feature-extraction
LongformerModel,feature-extraction
MCTCTModel,feature-extraction
MPNetForMaskedLM,sentence-similarity
MPNetModel,feature-extraction
MT5Model,feature-extraction
MarianMTModel,text2text-generation
MarkupLMModel,feature-extraction
MaskFormerForInstanceSegmentation,image-segmentation
MegatronBertModel,feature-extraction
MgpstrForSceneTextRecognition,feature-extraction
MistralForCausalLM,text-generation
MobileBertModel,feature-extraction
MobileNetV1Model,image-feature-extraction
MobileNetV2Model,image-feature-extraction
mobilenetv3_small_100,image-classification
MobileViTForImageClassification,image-classification
ModernBertForMaskedLM,fill-mask
Phi4MMForCausalLM,MoE
MoonshineForConditionalGeneration,automatic-speech-recognition
MptForCausalLM,text-generation
MusicgenForConditionalGeneration,text-to-audio
NystromformerModel,feature-extraction
OPTModel,feature-extraction
Olmo2ForCausalLM,text-generation
OlmoForCausalLM,text-generation
OwlViTModel,feature-extraction
Owlv2Model,feature-extraction
PatchTSMixerForPrediction,no-pipeline-tag
PatchTSTForPrediction,no-pipeline-tag
PegasusModel,feature-extraction
Phi3ForCausalLM,text-generation
PhiForCausalLM,text-generation
PhiMoEForCausalLM,text-generation
Pix2StructForConditionalGeneration,image-to-text
PLBartForConditionalGeneration,text2text-generation
PoolFormerModel,image-feature-extraction
PvtForImageClassification,image-classification
Qwen2ForCausalLM,text-generation
Qwen2_5_VLForConditionalGeneration,image-text-to-text
Qwen3MoeForCausalLM,text-generation
RTDetrForObjectDetection,object-detection
RegNetModel,image-feature-extraction
RemBertModel,feature-extraction
ResNetForImageClassification,image-classification
RoFormerModel,feature-extraction
RobertaForMaskedLM,sentence-similarity
RobertaModel,feature-extraction
RtDetrV2ForObjectDetection,object-detection
SEWDModel,feature-extraction
SEWModel,feature-extraction
SamModel,mask-generation
SegformerModel,image-feature-extraction
SiglipModel,zero-shot-image-classification
SiglipVisionModel,image-feature-extraction
Speech2TextModel,feature-extraction
SpeechT5ForTextToSpeech,text-to-audio
SplinterModel,feature-extraction
SqueezeBertModel,feature-extraction
Swin2SRModel,image-feature-extraction
SwinModel,image-feature-extraction
Swinv2Model,image-feature-extraction
T5ForConditionalGeneration,text2text-generation
TableTransformerModel,image-feature-extraction
TableTransformerForObjectDetection,object-detection
UNet2DConditionModel,text-to-image
UniSpeechForSequenceClassification,audio-classification
ViTForImageClassification,image-classification
ViTMAEModel,image-feature-extraction
ViTMSNForImageClassification,image-classification
VisionEncoderDecoderModel,document-question-answering
VitPoseForPoseEstimation,keypoint-detection
VitsModel,text-to-audio
Wav2Vec2ConformerForCTC,automatic-speech-recognition
Wav2Vec2Model,feature-extraction
WhisperForConditionalGeneration,automatic-speech-recognition
XLMModel,feature-extraction
XLMRobertaForCausalLM,text-generation
XLMRobertaForMaskedLM,fill-mask
XLMRobertaModel,sentence-similarity
Wav2Vec2ForCTC,automatic-speech-recognition
YolosForObjectDetection,object-detection
YolosModel,image-feature-extraction
Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
emilyalsentzer/Bio_ClinicalBERT,fill-mask
nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video"""
)

# Closed list of known tasks, scanned in order by hub_api.task_from_tags.
__data_tasks__ = [
    "audio-classification",
    "automatic-speech-recognition",
    "document-question-answering",
    "feature-extraction",
    "fill-mask",
    "image-classification",
    "image-feature-extraction",
    "image-segmentation",
    "image-text-to-text",
    "image-to-text",
    "keypoint-detection",
    "mask-generation",
    "no-pipeline-tag",
    "object-detection",
    "reinforcement-learning",
    "sentence-similarity",
    "summarization",
    "text-classification",
    "text-generation",
    "text-to-image",
    "text-to-audio",
    "text2text-generation",
    "zero-shot-image-classification",
]

# Small model ids used by the unit tests (one per line);
# parsed by :func:`load_models_testing` below.
__models_testing__ = """
hf-internal-testing/tiny-random-BeitForImageClassification
hf-internal-testing/tiny-random-convnext
fxmarty/tiny-random-GemmaForCausalLM
hf-internal-testing/tiny-random-GPTNeoXForCausalLM
hf-internal-testing/tiny-random-GraniteForCausalLM
hf-internal-testing/tiny-random-HieraForImageClassification
fxmarty/tiny-llama-fast-tokenizer
sshleifer/tiny-marian-en-de
hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation
echarlaix/tiny-random-mistral
hf-internal-testing/tiny-random-mobilevit
hf-internal-testing/tiny-random-MoonshineForConditionalGeneration
hf-internal-testing/tiny-random-OlmoForCausalLM
hf-internal-testing/tiny-random-Olmo2ForCausalLM
echarlaix/tiny-random-PhiForCausalLM
Xenova/tiny-random-Phi3ForCausalLM
fxmarty/pix2struct-tiny-random
fxmarty/tiny-dummy-qwen2
hf-internal-testing/tiny-random-ViTMSNForImageClassification
hf-internal-testing/tiny-random-YolosModel
hf-internal-testing/tiny-xlm-roberta
HuggingFaceM4/tiny-random-idefics
"""
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
@functools.cache
def load_models_testing() -> List[str]:
    """Returns model ids for testing."""
    stripped = (line.strip() for line in __models_testing__.splitlines())
    return [line for line in stripped if line]
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@functools.cache
def load_architecture_task() -> Dict[str, str]:
    """
    Returns a dictionary mapping architectures to tasks.

    import pprint
    from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task
    pprint.pprint(load_architecture_task())
    """
    # Imported lazily so the module does not require pandas unless needed.
    import pandas

    table = pandas.read_csv(io.StringIO(__data_arch__))
    return dict(zip(table["architecture"], table["task"]))
|