onnx-diagnostic 0.2.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +452 -0
- onnx_diagnostic/doc.py +4 -4
- onnx_diagnostic/export/__init__.py +2 -1
- onnx_diagnostic/export/dynamic_shapes.py +574 -23
- onnx_diagnostic/export/validate.py +170 -0
- onnx_diagnostic/ext_test_case.py +151 -31
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +216 -0
- onnx_diagnostic/helpers/config_helper.py +80 -0
- onnx_diagnostic/{helpers.py → helpers/helper.py} +341 -662
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/onnx_helper.py +921 -0
- onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +4 -3
- onnx_diagnostic/helpers/rt_helper.py +47 -0
- onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +149 -55
- onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
- onnx_diagnostic/reference/ort_evaluator.py +7 -2
- onnx_diagnostic/tasks/__init__.py +48 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +165 -0
- onnx_diagnostic/tasks/fill_mask.py +67 -0
- onnx_diagnostic/tasks/image_classification.py +96 -0
- onnx_diagnostic/tasks/image_text_to_text.py +145 -0
- onnx_diagnostic/tasks/sentence_similarity.py +67 -0
- onnx_diagnostic/tasks/text2text_generation.py +172 -0
- onnx_diagnostic/tasks/text_classification.py +67 -0
- onnx_diagnostic/tasks/text_generation.py +248 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +106 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +111 -146
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +346 -57
- onnx_diagnostic/torch_export_patches/patch_inputs.py +203 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +41 -2
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +39 -49
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +254 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +203 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3571 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +151 -0
- onnx_diagnostic/torch_models/test_helper.py +1250 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +3 -4
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +3 -4
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/sbs.py +439 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/METADATA +14 -4
- onnx_diagnostic-0.4.0.dist-info/RECORD +86 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/WHEEL +1 -1
- onnx_diagnostic/cache_helpers.py +0 -104
- onnx_diagnostic/onnx_tools.py +0 -260
- onnx_diagnostic-0.2.2.dist-info/RECORD +0 -59
- /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import importlib
|
|
3
|
+
import inspect
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
6
|
+
import transformers
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
    """
    Checks the configuration has all the attributes in ``args``.
    Raises an exception otherwise.

    :param config: configuration, an object exposing attributes or
        a dictionary (its keys are then accepted as attributes)
    :param args: every item is either an attribute name or a tuple of
        alternatives, an alternative is an attribute name or a sequence
        of names which must all be present, one valid alternative is enough
    :raises AssertionError: if an item is neither a string nor a tuple
        or if a required attribute is missing
    """

    def _has(name: Any) -> bool:
        # A dictionary exposes its keys the same way an object exposes attributes.
        return (isinstance(config, dict) and name in config) or hasattr(config, name)

    for a in args:
        if isinstance(a, str):
            if not _has(a):
                raise AssertionError(f"Missing attribute {a!r} in\n{config}")
        elif isinstance(a, tuple):
            # A string alternative is one single name: it must never be
            # iterated character by character (``"ab"`` is not ``"a"`` and ``"b"``).
            if not any(
                _has(name) if isinstance(name, str) else all(_has(_) for _ in name)
                for name in a
            ):
                raise AssertionError(f"All attributes in {a!r} are missing from\n{config}")
        else:
            # Explicit raise instead of ``assert``: the check survives ``python -O``.
            raise AssertionError(f"unexpected type {type(a)} in {args!r}")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def update_config(config: Any, mkwargs: Dict[str, Any]):
    """
    Updates a configuration with different values.

    :param config: configuration updated in place, an object with
        attributes or a dictionary (a sub-configuration stored as
        a dictionary is updated by key)
    :param mkwargs: new values, a dictionary value is applied recursively
        to the existing sub-configuration of the same name
    :raises AssertionError: if a dictionary value targets an attribute
        or a key missing from the configuration
    """
    is_dict = isinstance(config, dict)
    for k, v in mkwargs.items():
        if isinstance(v, dict):
            found = k in config if is_dict else hasattr(config, k)
            if not found:
                raise AssertionError(
                    f"missing attribute {k!r} in config={config}, cannot update it with {v}"
                )
            update_config(config[k] if is_dict else getattr(config, k), v)
        elif is_dict:
            # ``setattr`` fails on a plain dictionary, values are stored by key.
            config[k] = v
        else:
            setattr(config, k, v)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _pick(config, *atts):
    """
    Returns the first value found in the configuration.

    Every candidate is tried in order. A string is an attribute name,
    its value is returned as soon as the attribute exists. A tuple is
    ``(function, name1, name2, ...)``: when all the names exist, the function
    is called on the list of their values and the result is returned.
    Any other candidate is ignored. An ``AssertionError`` is raised
    if no candidate matches.
    """
    for candidate in atts:
        if isinstance(candidate, str):
            if not hasattr(config, candidate):
                continue
            return getattr(config, candidate)
        if not isinstance(candidate, tuple):
            continue
        names = candidate[1:]
        if not all(hasattr(config, name) for name in names):
            continue
        values = [getattr(config, name) for name in names]
        return candidate[0](values)
    raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@functools.cache
def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
    """
    Retrieves the configuration class for a given architecture.
    The name is guessed from the source code of the module defining
    the architecture: it looks for every annotation ``config: <ClassName>``
    and expects all of them to refer to the same class.

    :param arch: architecture (class name), an attribute of :epkg:`transformers`
    :param exc: raise an exception if no candidate is found,
        otherwise None is returned in that case
    :return: type
    """
    module_name = getattr(transformers, arch).__module__
    code = inspect.getsource(importlib.import_module(module_name))
    candidates = set(re.findall("config: ([A-Za-z0-9]+)", code))
    if not candidates:
        assert not exc, (
            f"Unable to guess Configuration class name for arch={arch!r}, "
            f"module={module_name!r}, no candidate, source is\n{code}"
        )
        return None
    # More than one distinct name is ambiguous, this fails whatever exc is.
    assert len(candidates) == 1, (
        f"Unable to guess Configuration class name for arch={arch!r}, "
        f"module={module_name!r}, found={candidates} (#{len(candidates)}), "
        f"source is\n{code}"
    )
    (config_name,) = candidates
    return getattr(transformers, config_name)
|