onnx-diagnostic 0.8.0 (onnx_diagnostic-0.8.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
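For reference, the file list above can be reproduced locally from the wheel itself. A minimal sketch, assuming the wheel has already been downloaded to the current directory under its canonical file name (an assumption; adjust the path as needed):

import zipfile

# Hypothetical local path to the downloaded wheel.
wheel_path = "onnx_diagnostic-0.8.0-py3-none-any.whl"

with zipfile.ZipFile(wheel_path) as whl:
    for info in whl.infolist():
        print(f"{info.filename}: {info.file_size} bytes")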
onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -0,0 +1,422 @@
+ import copy
+ import functools
+ import json
+ import os
+ import pprint
+ import sys
+ from typing import Any, Dict, List, Optional, Union
+ import transformers
+ from huggingface_hub import HfApi, model_info, hf_hub_download, list_repo_files
+ from ...helpers.config_helper import update_config
+ from . import hub_data_cached_configs
+ from .hub_data import __date__, __data_tasks__, load_architecture_task, __data_arch_values__
+
+
+ @functools.cache
+ def get_architecture_default_values(architecture: str):
+     """
+     The configuration may be missing information needed to build the dummy inputs.
+     This function returns the missing pieces.
+     """
+     assert architecture in __data_arch_values__, (
+         f"No known default values for {architecture!r}, "
+         f"expecting one architecture in {', '.join(sorted(__data_arch_values__))}"
+     )
+     return __data_arch_values__[architecture]
+
+
+ @functools.cache
+ def _retrieve_cached_configurations() -> Dict[str, transformers.PretrainedConfig]:
+     res = {}
+     for k, v in hub_data_cached_configs.__dict__.items():
+         if k.startswith("_ccached_"):
+             doc = v.__doc__
+             res[doc] = v
+     return res
+
+
+ def get_cached_configuration(
+     name: str, exc: bool = False, **kwargs
+ ) -> Optional[transformers.PretrainedConfig]:
+     """
+     Returns a cached configuration to avoid too many accesses to the internet.
+     It returns None if the configuration is not cached. The list of cached models follows.
+     If *exc* is True or if environment variable ``NOHTTP`` is defined,
+     the function raises an exception if *name* is not found.
+
+     .. runpython::
+
+         import pprint
+         from onnx_diagnostic.torch_models.hghub.hub_api import _retrieve_cached_configurations
+
+         configs = _retrieve_cached_configurations()
+         pprint.pprint(sorted(configs))
+     """
+     cached = _retrieve_cached_configurations()
+     assert cached, "no cached configuration, which is weird"
+     if name in cached:
+         conf = cached[name]()
+         if kwargs:
+             conf = copy.deepcopy(conf)
+             update_config(conf, kwargs)
+         return conf
+     assert not exc and not os.environ.get("NOHTTP", ""), (
+         f"Unable to find {name!r} (exc={exc}, "
+         f"NOHTTP={os.environ.get('NOHTTP', '')!r}) "
+         f"in {pprint.pformat(sorted(cached))}"
+     )
+     return None
+
+
+ def get_pretrained_config(
+     model_id: str,
+     trust_remote_code: bool = True,
+     use_preinstalled: bool = True,
+     subfolder: Optional[str] = None,
+     use_only_preinstalled: bool = False,
+     **kwargs,
+ ) -> Any:
+     """
+     Returns the config for a model_id.
+
+     :param model_id: model id
+     :param trust_remote_code: trust_remote_code,
+         see :meth:`transformers.AutoConfig.from_pretrained`
+     :param use_preinstalled: if use_preinstalled, uses this version to avoid
+         accessing the network, if available, it is returned by
+         :func:`get_cached_configuration`, the cached list is mostly for
+         unit tests
+     :param subfolder: subfolder for the given model id
+     :param use_only_preinstalled: if True, raises an exception if not preinstalled
+     :param kwargs: additional kwargs
+     :return: a configuration
+     """
+     if use_preinstalled:
+         conf = get_cached_configuration(
+             model_id, exc=use_only_preinstalled, subfolder=subfolder, **kwargs
+         )
+         if conf is not None:
+             return conf
+     assert not use_only_preinstalled, (
+         f"Inconsistencies: use_only_preinstalled={use_only_preinstalled}, "
+         f"use_preinstalled={use_preinstalled!r}"
+     )
+     if subfolder:
+         try:
+             return transformers.AutoConfig.from_pretrained(
+                 model_id, trust_remote_code=trust_remote_code, subfolder=subfolder, **kwargs
+             )
+         except ValueError:
+             # Then we try to download it.
+             config = hf_hub_download(
+                 model_id, filename="config.json", subfolder=subfolder, **kwargs
+             )
+             try:
+                 return transformers.AutoConfig.from_pretrained(
+                     config, trust_remote_code=trust_remote_code, **kwargs
+                 )
+             except ValueError:
+                 # Diffusers uses a dictionary.
+                 with open(config, "r") as f:
+                     return json.load(f)
+     return transformers.AutoConfig.from_pretrained(
+         model_id, trust_remote_code=trust_remote_code, **kwargs
+     )
+
+
+ def get_model_info(model_id) -> Any:
+     """Returns the model info for a model_id."""
+     return model_info(model_id)
+
+
+ def _guess_task_from_config(config: Any) -> Optional[str]:
+     """Tries to infer a task from the configuration."""
+     if hasattr(config, "bbox_loss_coefficient") and hasattr(config, "giou_loss_coefficient"):
+         return "object-detection"
+     if hasattr(config, "architecture") and config.architecture:
+         return task_from_arch(config.architecture)
+     return None
+
+
+ @functools.cache
+ def task_from_arch(
+     arch: str,
+     default_value: Optional[str] = None,
+     model_id: Optional[str] = None,
+     subfolder: Optional[str] = None,
+ ) -> str:
+     """
+     This function relies on stored information. That information needs to be refreshed.
+
+     :param arch: architecture name
+     :param default_value: default value in case the task cannot be determined
+     :param model_id: unused unless the architecture does not help
+     :param subfolder: subfolder
+     :return: task
+
+     .. runpython::
+
+         from onnx_diagnostic.torch_models.hghub.hub_data import __date__
+         print("last refresh", __date__)
+
+     List of supported architectures, see
+     :func:`load_architecture_task
+     <onnx_diagnostic.torch_models.hghub.hub_data.load_architecture_task>`.
+     """
+     data = load_architecture_task()
+     if arch not in data and model_id:
+         # Let's try with the model id.
+         return task_from_id(model_id, subfolder=subfolder)
+     if default_value is not None:
+         return data.get(arch, default_value)
+     assert arch in data, (
+         f"Architecture {arch!r} is unknown, last refresh in {__date__}. "
+         f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__`` "
+         f"needs to be updated (model_id={(model_id or '?')!r})."
+     )
+     return data[arch]
+
+
+ def _trygetattr(config, attname):
+     try:
+         return getattr(config, attname)
+     except AttributeError:
+         return None
+
+
+ def architecture_from_config(config) -> Optional[str]:
+     """Guesses the architecture (class) of the model described by this config."""
+     if isinstance(config, dict):
+         if "_class_name" in config:
+             return config["_class_name"]
+         if "architecture" in config:
+             return config["architecture"]
+         if config.get("architectures", []):
+             return config["architectures"][0]
+     if hasattr(config, "_class_name"):
+         return config._class_name
+     if hasattr(config, "architecture"):
+         return config.architecture
+     if hasattr(config, "architectures") and config.architectures:
+         return config.architectures[0]
+     if hasattr(config, "__dict__"):
+         if "_class_name" in config.__dict__:
+             return config.__dict__["_class_name"]
+         if "architecture" in config.__dict__:
+             return config.__dict__["architecture"]
+         if config.__dict__.get("architectures", []):
+             return config.__dict__["architectures"][0]
+     return None
+
+
+ def find_package_source(config) -> Optional[str]:
+     """Guesses the package the model class comes from."""
+     if isinstance(config, dict):
+         if "_diffusers_version" in config:
+             return "diffusers"
+     if hasattr(config, "_diffusers_version"):
+         return "diffusers"
+     if hasattr(config, "__dict__"):
+         if "_diffusers_version" in config.__dict__:
+             return "diffusers"
+     return "transformers"
+
+
+ def task_from_id(
+     model_id: str,
+     default_value: Optional[str] = None,
+     pretrained: bool = False,
+     fall_back_to_pretrained: bool = True,
+     subfolder: Optional[str] = None,
+ ) -> str:
+     """
+     Returns the task attached to a model id.
+
+     :param model_id: model id
+     :param default_value: if specified, the function returns this value
+         if the task cannot be determined
+     :param pretrained: if True, skips the pipeline tag lookup and relies on the configuration
+     :param fall_back_to_pretrained: falls back to the pretrained config
+         if the pipeline tag cannot be retrieved
+     :param subfolder: subfolder
+     :return: task
+     """
+     if not pretrained:
+         try:
+             transformers.pipelines.get_task(model_id)
+         except RuntimeError:
+             if not fall_back_to_pretrained:
+                 raise
+     config = get_pretrained_config(model_id, subfolder=subfolder)
+     tag = _trygetattr(config, "pipeline_tag")
+     if tag is not None:
+         return tag
+
+     guess = _guess_task_from_config(config)
+     if guess is not None:
+         return guess
+     data = load_architecture_task()
+     if subfolder:
+         full_id = f"{model_id}//{subfolder}"
+         if full_id in data:
+             return data[full_id]
+     if model_id in data:
+         return data[model_id]
+     arch = architecture_from_config(config)
+     if arch is None:
+         if model_id.startswith("google/bert_"):
+             return "fill-mask"
+     assert arch is not None, (
+         f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
+         f"config={config}. The task can be added in "
+         f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
+     )
+     return task_from_arch(arch, default_value=default_value)
+
+
+ def task_from_tags(tags: Union[str, List[str]]) -> str:
+     """
+     Guesses the task from the list of tags.
+     If given by a string, ``|`` should be the separator.
+     """
+     if isinstance(tags, str):
+         tags = tags.split("|")
+     stags = set(tags)
+     for task in __data_tasks__:
+         if task in stags:
+             return task
+     raise ValueError(f"Unable to guess the task from tags={tags!r}")
+
+
+ def enumerate_model_list(
+     n: int = 50,
+     pipeline_tag: Optional[str] = None,
+     search: Optional[str] = None,
+     dump: Optional[str] = None,
+     filter: Optional[Union[str, List[str]]] = None,
+     verbose: int = 0,
+ ):
+     """
+     Enumerates models coming from :epkg:`huggingface_hub`.
+
+     :param n: number of models to retrieve (-1 for all)
+     :param pipeline_tag: see :meth:`huggingface_hub.HfApi.list_models`
+     :param search: see :meth:`huggingface_hub.HfApi.list_models`
+     :param filter: see :meth:`huggingface_hub.HfApi.list_models`
+     :param dump: dumps the result in this csv file
+     :param verbose: show progress
+     """
+     api = HfApi()
+     models = api.list_models(
+         pipeline_tag=pipeline_tag,
+         search=search,
+         full=True,
+         filter=filter,
+         limit=n if n > 0 else None,
+     )
+     seen = 0
+     found = 0
+
+     if dump:
+         with open(dump, "w") as f:
+             f.write(
+                 ",".join(
+                     [
+                         "id",
+                         "model_name",
+                         "author",
+                         "created_at",
+                         "last_modified",
+                         "downloads",
+                         "downloads_all_time",
+                         "likes",
+                         "trending_score",
+                         "private",
+                         "gated",
+                         "tags",
+                     ]
+                 )
+             )
+             f.write("\n")
+
+     for m in models:
+         seen += 1  # noqa: SIM113
+         if verbose and seen % 1000 == 0:
+             print(f"[enumerate_model_list] {seen} models, found {found}")
+         if verbose > 1:
+             print(
+                 f"[enumerate_model_list] id={m.id!r}, "
+                 f"library={m.library_name!r}, task={m.task!r}"
+             )
+         with open(dump, "a") as f:  # type: ignore
+             f.write(
+                 ",".join(
+                     map(
+                         str,
+                         [
+                             m.id,
+                             getattr(m, "model_name", "") or "",
+                             m.author or "",
+                             str(m.created_at or "").split(" ")[0],
+                             str(m.last_modified or "").split(" ")[0],
+                             m.downloads or "",
+                             m.downloads_all_time or "",
+                             m.likes or "",
+                             m.trending_score or "",
+                             m.private or "",
+                             m.gated or "",
+                             (
+                                 ("|".join(m.tags)).replace(",", "_").replace(" ", "_")
+                                 if m.tags
+                                 else ""
+                             ),
+                         ],
+                     )
+                 )
+             )
+             f.write("\n")
+         yield m
+         found += 1  # noqa: SIM113
+         if n >= 0:
+             n -= 1
+             if n == 0:
+                 break
+
+
+ def download_code_modelid(
+     model_id: str, verbose: int = 0, add_path_to_sys_path: bool = True
+ ) -> List[str]:
+     """
+     Downloads the code for a given model id.
+
+     :param model_id: model id
+     :param verbose: verbosity
+     :param add_path_to_sys_path: add folder where the files are downloaded to sys.path
+     :return: list of downloaded files
+     """
+     if verbose:
+         print(f"[download_code_modelid] retrieve file list for {model_id!r}")
+     files = list_repo_files(model_id)
+     pyfiles = [name for name in files if os.path.splitext(name)[-1] == ".py"]
+     if verbose:
+         print(f"[download_code_modelid] python files {pyfiles}")
+     absfiles = []
+     paths = set()
+     for i, name in enumerate(pyfiles):
+         if verbose:
+             print(f"[download_code_modelid] download file {i+1}/{len(pyfiles)}: {name!r}")
+         r = hf_hub_download(repo_id=model_id, filename=name)
+         p = os.path.split(r)[0]
+         paths.add(p)
+         absfiles.append(r)
+     if add_path_to_sys_path:
+         for p in paths:
+             init = os.path.join(p, "__init__.py")
+             if not os.path.exists(init):
+                 with open(init, "w"):
+                     pass
+             if p in sys.path:
+                 continue
+             if verbose:
+                 print(f"[download_code_modelid] add {p!r} to 'sys.path'")
+             sys.path.insert(0, p)
+     return absfiles
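The helpers above are meant to be used together: get_pretrained_config first consults the configurations cached in hub_data_cached_configs, and task_from_id / task_from_arch map a model to one of the tasks listed in hub_data.py. A minimal usage sketch, assuming the package is installed; the model id below is taken from the testing list in hub_data.py and the calls may still reach the Hugging Face Hub if the configuration is not cached:

from onnx_diagnostic.torch_models.hghub.hub_api import (
    get_pretrained_config,
    task_from_arch,
    task_from_id,
)

# Architecture -> task lookup backed by the __data_arch__ table.
print(task_from_arch("LlamaForCausalLM"))  # text-generation

# Configuration lookup; use_preinstalled=True tries the cached configs first.
config = get_pretrained_config("sshleifer/tiny-marian-en-de", use_preinstalled=True)

# Task lookup for a model id (pipeline tag, config, then the architecture table).
print(task_from_id("sshleifer/tiny-marian-en-de"))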
onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -0,0 +1,234 @@
+ import io
+ import functools
+ import textwrap
+ from typing import Dict, List
+
+ __date__ = "2025-06-21"
+
+ __data_arch_values__ = {"ResNetForImageClassification": dict(image_size=224)}
+
+ __data_arch__ = textwrap.dedent(
+     """
+     architecture,task
+     ASTModel,feature-extraction
+     AutoencoderKL,image-to-image
+     AlbertModel,feature-extraction
+     BeitForImageClassification,image-classification
+     BartForConditionalGeneration,summarization
+     BartModel,feature-extraction
+     BertForMaskedLM,fill-mask
+     BertForSequenceClassification,text-classification
+     BertModel,sentence-similarity
+     BigBirdModel,feature-extraction
+     BlenderbotModel,feature-extraction
+     BloomModel,feature-extraction
+     CLIPModel,zero-shot-image-classification
+     CLIPTextModel,feature-extraction
+     CLIPVisionModel,feature-extraction
+     CamembertModel,feature-extraction
+     CodeGenModel,feature-extraction
+     ConvBertModel,feature-extraction
+     ConvNextForImageClassification,image-classification
+     ConvNextV2Model,image-feature-extraction
+     CosmosTransformer3DModel,image-to-video
+     CvtModel,feature-extraction
+     DPTModel,image-feature-extraction
+     Data2VecAudioModel,feature-extraction
+     Data2VecTextModel,feature-extraction
+     Data2VecVisionModel,image-feature-extraction
+     DebertaModel,feature-extraction
+     DebertaV2Model,feature-extraction
+     DecisionTransformerModel,reinforcement-learning
+     DeepseekV3ForCausalLM,text-generation
+     DeiTModel,image-feature-extraction
+     DetrModel,image-feature-extraction
+     Dinov2Model,image-feature-extraction
+     DistilBertForSequenceClassification,text-classification
+     DistilBertModel,feature-extraction
+     DonutSwinModel,feature-extraction
+     ElectraModel,feature-extraction
+     EsmModel,feature-extraction
+     FalconMambaForCausalLM,text-generation
+     GLPNModel,image-feature-extraction
+     GPT2LMHeadModel,text-generation
+     GPTBigCodeModel,feature-extraction
+     GPTJModel,feature-extraction
+     GPTNeoModel,feature-extraction
+     GPTNeoXForCausalLM,text-generation
+     GptOssForCausalLM,text-generation
+     GemmaForCausalLM,text-generation
+     Gemma2ForCausalLM,text-generation
+     Gemma3ForConditionalGeneration,image-text-to-text
+     Gemma3ForCausalLM,text-generation
+     Glm4vMoeForConditionalGeneration,image-text-to-text
+     GraniteForCausalLM,text-generation
+     GroupViTModel,feature-extraction
+     HieraForImageClassification,image-classification
+     HubertModel,feature-extraction
+     IBertModel,feature-extraction
+     IdeficsForVisionText2Text,image-text-to-text
+     ImageGPTModel,image-feature-extraction
+     LayoutLMModel,feature-extraction
+     LayoutLMv3Model,feature-extraction
+     LevitModel,image-feature-extraction
+     LiltModel,feature-extraction
+     LlamaForCausalLM,text-generation
+     LongT5Model,feature-extraction
+     LongformerModel,feature-extraction
+     MCTCTModel,feature-extraction
+     MPNetForMaskedLM,sentence-similarity
+     MPNetModel,feature-extraction
+     MT5Model,feature-extraction
+     MarianMTModel,text2text-generation
+     MarkupLMModel,feature-extraction
+     MaskFormerForInstanceSegmentation,image-segmentation
+     MegatronBertModel,feature-extraction
+     MgpstrForSceneTextRecognition,feature-extraction
+     MistralForCausalLM,text-generation
+     MobileBertModel,feature-extraction
+     MobileNetV1Model,image-feature-extraction
+     MobileNetV2Model,image-feature-extraction
+     mobilenetv3_small_100,image-classification
+     MobileViTForImageClassification,image-classification
+     ModernBertForMaskedLM,fill-mask
+     Phi4MMForCausalLM,MoE
+     MoonshineForConditionalGeneration,automatic-speech-recognition
+     MptForCausalLM,text-generation
+     MusicgenForConditionalGeneration,text-to-audio
+     NystromformerModel,feature-extraction
+     OPTModel,feature-extraction
+     Olmo2ForCausalLM,text-generation
+     OlmoForCausalLM,text-generation
+     OwlViTModel,feature-extraction
+     Owlv2Model,feature-extraction
+     PatchTSMixerForPrediction,no-pipeline-tag
+     PatchTSTForPrediction,no-pipeline-tag
+     PegasusModel,feature-extraction
+     Phi3ForCausalLM,text-generation
+     PhiForCausalLM,text-generation
+     PhiMoEForCausalLM,text-generation
+     Pix2StructForConditionalGeneration,image-to-text
+     PLBartForConditionalGeneration,text2text-generation
+     PoolFormerModel,image-feature-extraction
+     PvtForImageClassification,image-classification
+     Qwen2ForCausalLM,text-generation
+     Qwen2_5_VLForConditionalGeneration,image-text-to-text
+     Qwen3MoeForCausalLM,text-generation
+     RTDetrForObjectDetection,object-detection
+     RegNetModel,image-feature-extraction
+     RemBertModel,feature-extraction
+     ResNetForImageClassification,image-classification
+     RoFormerModel,feature-extraction
+     RobertaForMaskedLM,sentence-similarity
+     RobertaModel,feature-extraction
+     RtDetrV2ForObjectDetection,object-detection
+     SEWDModel,feature-extraction
+     SEWModel,feature-extraction
+     SamModel,mask-generation
+     SegformerModel,image-feature-extraction
+     SiglipModel,zero-shot-image-classification
+     SiglipVisionModel,image-feature-extraction
+     Speech2TextModel,feature-extraction
+     SpeechT5ForTextToSpeech,text-to-audio
+     SplinterModel,feature-extraction
+     SqueezeBertModel,feature-extraction
+     Swin2SRModel,image-feature-extraction
+     SwinModel,image-feature-extraction
+     Swinv2Model,image-feature-extraction
+     T5ForConditionalGeneration,text2text-generation
+     TableTransformerModel,image-feature-extraction
+     TableTransformerForObjectDetection,object-detection
+     UNet2DConditionModel,text-to-image
+     UniSpeechForSequenceClassification,audio-classification
+     ViTForImageClassification,image-classification
+     ViTMAEModel,image-feature-extraction
+     ViTMSNForImageClassification,image-classification
+     VisionEncoderDecoderModel,document-question-answering
+     VitPoseForPoseEstimation,keypoint-detection
+     VitsModel,text-to-audio
+     Wav2Vec2ConformerForCTC,automatic-speech-recognition
+     Wav2Vec2Model,feature-extraction
+     WhisperForConditionalGeneration,automatic-speech-recognition
+     XLMModel,feature-extraction
+     XLMRobertaForCausalLM,text-generation
+     XLMRobertaForMaskedLM,fill-mask
+     XLMRobertaModel,sentence-similarity
+     Wav2Vec2ForCTC,automatic-speech-recognition
+     YolosForObjectDetection,object-detection
+     YolosModel,image-feature-extraction
+     Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
+     emilyalsentzer/Bio_ClinicalBERT,fill-mask
+     nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video"""
+ )
+
+ __data_tasks__ = [
+     "audio-classification",
+     "automatic-speech-recognition",
+     "document-question-answering",
+     "feature-extraction",
+     "fill-mask",
+     "image-classification",
+     "image-feature-extraction",
+     "image-segmentation",
+     "image-text-to-text",
+     "image-to-text",
+     "keypoint-detection",
+     "mask-generation",
+     "no-pipeline-tag",
+     "object-detection",
+     "reinforcement-learning",
+     "sentence-similarity",
+     "summarization",
+     "text-classification",
+     "text-generation",
+     "text-to-image",
+     "text-to-audio",
+     "text2text-generation",
+     "zero-shot-image-classification",
+ ]
+
+ __models_testing__ = """
+ hf-internal-testing/tiny-random-BeitForImageClassification
+ hf-internal-testing/tiny-random-convnext
+ fxmarty/tiny-random-GemmaForCausalLM
+ hf-internal-testing/tiny-random-GPTNeoXForCausalLM
+ hf-internal-testing/tiny-random-GraniteForCausalLM
+ hf-internal-testing/tiny-random-HieraForImageClassification
+ fxmarty/tiny-llama-fast-tokenizer
+ sshleifer/tiny-marian-en-de
+ hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation
+ echarlaix/tiny-random-mistral
+ hf-internal-testing/tiny-random-mobilevit
+ hf-internal-testing/tiny-random-MoonshineForConditionalGeneration
+ hf-internal-testing/tiny-random-OlmoForCausalLM
+ hf-internal-testing/tiny-random-Olmo2ForCausalLM
+ echarlaix/tiny-random-PhiForCausalLM
+ Xenova/tiny-random-Phi3ForCausalLM
+ fxmarty/pix2struct-tiny-random
+ fxmarty/tiny-dummy-qwen2
+ hf-internal-testing/tiny-random-ViTMSNForImageClassification
+ hf-internal-testing/tiny-random-YolosModel
+ hf-internal-testing/tiny-xlm-roberta
+ HuggingFaceM4/tiny-random-idefics
+ """
+
+
+ @functools.cache
+ def load_models_testing() -> List[str]:
+     """Returns model ids for testing."""
+     return [_.strip() for _ in __models_testing__.split("\n") if _.strip()]
+
+
+ @functools.cache
+ def load_architecture_task() -> Dict[str, str]:
+     """
+     Returns a dictionary mapping architectures to tasks.
+
+         import pprint
+         from onnx_diagnostic.torch_models.hghub.hub_data import load_architecture_task
+         pprint.pprint(load_architecture_task())
+     """
+     import pandas
+
+     df = pandas.read_csv(io.StringIO(__data_arch__))
+     return dict(zip(list(df["architecture"]), list(df["task"])))
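To close the loop on the data module, a short sketch of how the table and tag helpers can be queried; it assumes pandas is installed, since load_architecture_task imports it lazily:

from onnx_diagnostic.torch_models.hghub.hub_data import (
    load_architecture_task,
    load_models_testing,
)
from onnx_diagnostic.torch_models.hghub.hub_api import task_from_tags

# The CSV stored in __data_arch__ is parsed once and cached by functools.cache.
arch2task = load_architecture_task()
print(arch2task["WhisperForConditionalGeneration"])  # automatic-speech-recognition

# Model ids reserved for unit tests.
print(len(load_models_testing()))

# Tags given as a single string use '|' as the separator.
print(task_from_tags("pytorch|text-generation|en"))  # text-generation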