onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +56 -3
- onnx_diagnostic/export/dynamic_shapes.py +24 -10
- onnx_diagnostic/export/shape_helper.py +6 -2
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +6 -6
- onnx_diagnostic/helpers/cache_helper.py +326 -18
- onnx_diagnostic/helpers/config_helper.py +10 -0
- onnx_diagnostic/helpers/helper.py +152 -11
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
- onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +7 -3
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +6 -2
- onnx_diagnostic/tasks/image_text_to_text.py +289 -62
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
- onnx_diagnostic/tasks/object_detection.py +6 -2
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +7 -2
- onnx_diagnostic/tasks/text2text_generation.py +7 -2
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +14 -16
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
- onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/validate.py +1 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/text2text_generation.py

@@ -1,7 +1,12 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "text2text-generation"
 
@@ -12,7 +17,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if hasattr(config, "num_decoder_layers"):
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers,
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     update_config(config, kwargs)
     return kwargs
 
onnx_diagnostic/tasks/text_classification.py

@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "text-classification"
 
@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size."""
    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
    kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers,
+        num_hidden_layers=min(config.num_hidden_layers, nhl()),
        num_attention_heads=min(config.num_attention_heads, 4),
    )
    update_config(config, kwargs)
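Both task modules above (and the other `onnx_diagnostic/tasks/*` files in the list) replace a hardcoded layer cap with `default_num_hidden_layers`, imported as `nhl` and called as `nhl()`, so the reduction target is defined in one place. The diff does not show the helper's body; below is a minimal sketch of what a centralized, overridable default could look like — the environment-variable name and the fallback value of 2 are assumptions, not the library's actual code.

```python
import os

def default_num_hidden_layers() -> int:
    """Process-wide default used when shrinking a model config for tests.

    Hypothetical sketch: the real helper lives in
    onnx_diagnostic/helpers/config_helper.py; the env-var name and the
    fallback of 2 are assumptions.
    """
    return int(os.environ.get("DEFAULT_NUM_HIDDEN_LAYERS", "2"))

# Usage mirroring the diff:
# config.num_hidden_layers = min(config.num_hidden_layers, default_num_hidden_layers())
```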
onnx_diagnostic/tasks/text_generation.py

@@ -1,13 +1,17 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-import transformers
 from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
     make_sliding_window_cache,
     make_static_cache,
 )
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "text-generation"
 
@@ -26,7 +30,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if config.__class__.__name__ == "FalconMambaConfig":
         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
-            num_hidden_layers=min(config.num_hidden_layers,
+            num_hidden_layers=min(config.num_hidden_layers, nhl()),
             intermediate_size=256 if config is None else min(512, config.intermediate_size),
             hidden_size=512 if config is None else min(512, config.hidden_size),
             cls_cache="MambaCache",
@@ -38,24 +42,13 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             head_dim=getattr(
                 config, "head_dim", config.hidden_size // config.num_attention_heads
             ),
-            num_hidden_layers=min(config.num_hidden_layers,
+            num_hidden_layers=min(config.num_hidden_layers, nhl()),
             num_key_value_heads=(
                 config.num_key_value_heads
                 if hasattr(config, "num_key_value_heads")
                 else config.num_attention_heads
             ),
-            hidden_size=(
-                min(config.hidden_size, 4096 // 4)
-                if config.hidden_size % 64 == 0
-                else config.hidden_size
-            ),
         )
-        if config is None or hasattr(config, "intermediate_size"):
-            kwargs["intermediate_size"] = (
-                min(config.intermediate_size, 24576 // 4)
-                if config.intermediate_size % 4 == 0
-                else config.intermediate_size
-            )
     update_config(config, kwargs)
     return kwargs
 
@@ -95,9 +88,14 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
         assert cls_cache in (
             "MambaCache",
-
+            MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
         sequence_length = (
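Two compatibility changes stand out here: `MambaCache` is now imported from `transformers.models.mamba.modeling_mamba` when available, with a fallback to `transformers.cache_utils` for older transformers releases, and the membership assert accepts the class object as well as the string `"MambaCache"`. A small illustrative sketch of normalizing such a `cls_cache` argument — this helper is not part of the library:

```python
from typing import Union

def cache_class_name(cls_cache: Union[str, type]) -> str:
    # Accept either the name "MambaCache" or the class itself,
    # mirroring the relaxed membership test in get_inputs.
    return cls_cache if isinstance(cls_cache, str) else cls_cache.__name__

assert cache_class_name("MambaCache") == "MambaCache"
```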
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -361,7 +361,7 @@ def torch_export_patches(
     torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes
 
     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
-    if catch_constraints:
+    if patch_torch and catch_constraints:
         if verbose:
             print("[torch_export_patches] modifies shape constraints")
         f_produce_guards_and_solve_constraints = (
@@ -513,7 +513,7 @@ def torch_export_patches(
         if verbose:
             print("[torch_export_patches] restored pytorch functions")
 
-        if stop_if_static:
+        if patch_torch and stop_if_static:
             if verbose:
                 print("[torch_export_patches] restored ShapeEnv._set_replacement")
 
@@ -529,7 +529,7 @@ def torch_export_patches(
             print("[torch_export_patches] restored ShapeEnv._check_frozen")
             ShapeEnv._check_frozen = f_shape_env__check_frozen
 
-        if catch_constraints:
+        if patch_torch and catch_constraints:
             # to catch or skip dynamic_shapes issues
             torch._export.non_strict_utils.produce_guards_and_solve_constraints = (
                 f_produce_guards_and_solve_constraints
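All three changes add the same `patch_torch and ...` guard, so the constraint hooks are neither installed nor restored when torch patching is disabled; previously the restore path could run for hooks that were never installed. A minimal sketch of the symmetry this restores, where the same condition gates both setup and teardown (illustrative, not the library's code):

```python
import contextlib

@contextlib.contextmanager
def guarded_patch(obj, name, replacement, enabled=True):
    # Patch obj.<name> only when enabled, and restore under the same
    # condition, so setup and teardown can never get out of sync.
    if not enabled:
        yield
        return
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)
```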
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -6,12 +6,17 @@ import torch
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
+    HybridCache,
     SlidingWindowCache,
     StaticCache,
 )
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
@@ -161,6 +166,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
+        flatten_hybrid_cache,
+        unflatten_hybrid_cache,
+        flatten_with_keys_hybrid_cache,
         flatten_mamba_cache,
         unflatten_mamba_cache,
         flatten_with_keys_mamba_cache,
@@ -187,6 +195,14 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
+        HybridCache: lambda verbose=verbose: register_class_serialization(
+            HybridCache,
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            verbose=verbose,
+        ),
         MambaCache: lambda verbose=verbose: register_class_serialization(
             MambaCache,
             flatten_mamba_cache,
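The new `HybridCache` entry wires a flatten/unflatten/flatten-with-keys triple into torch's pytree machinery via `register_class_serialization`, so `torch.export` can trace through `HybridCache` inputs like any other container. A self-contained sketch of that registration pattern on a toy class, assuming the `torch.utils._pytree.register_pytree_node` API of PyTorch 2.2+ (`HybridCache`'s real attributes differ):

```python
import torch
import torch.utils._pytree as pytree

class ToyCache:
    # Stand-in for a transformers cache class; illustrative only.
    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache
        self.value_cache = value_cache

def flatten_toy_cache(cache):
    # Return the tensor leaves plus a static context (none needed here).
    return [cache.key_cache, cache.value_cache], None

def unflatten_toy_cache(values, context):
    return ToyCache(values[0], values[1])

pytree.register_pytree_node(
    ToyCache,
    flatten_toy_cache,
    unflatten_toy_cache,
    serialized_type_name="ToyCache",
)

# Once registered, export-time tracing sees the cache as a plain pytree.
leaves, spec = pytree.tree_flatten(ToyCache(torch.zeros(2, 3), torch.ones(2, 3)))
```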
onnx_diagnostic/export/dynamic_shapes.py

@@ -34,7 +34,7 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
         f"Inconsistencies in subset={subset}, found={values}, "
         f"it cannot be a {cls}, value={string_type(value)}"
     )
-    cache_length = len(value.key_cache)
+    cache_length = len(value.layers if hasattr(value, "layers") else value.key_cache)
     for v in subset.values():
         axes = v
         break
@@ -70,6 +70,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
     :param verbose: verbosity
     :return: (args, kwargs, dynamic shapes)
     """
+    from ..helpers.cache_helper import CacheKeyValue
+
     new_kwargs = {}
     if args:
         assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
@@ -121,7 +123,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
             changes[k] = type(updated_kwargs[k])
             continue
         if isinstance(v, transformers.cache_utils.DynamicCache):
-
+            ca = CacheKeyValue(v)
+            updated_kwargs[k] = [ca.key_cache, ca.value_cache]
             changes[k] = type(v)
             continue
         raise NotImplementedError(
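Both hunks stop touching `key_cache`/`value_cache` directly and instead account for the `layers` attribute: recent transformers versions store per-layer state in `cache.layers` rather than parallel key/value lists. `CacheKeyValue` (from `onnx_diagnostic/helpers/cache_helper.py`, heavily extended in this release) normalizes that access. A plausible sketch of the adapter, assuming the newer layer objects expose `keys`/`values` tensors — not the library's actual implementation:

```python
class CacheKeyValue:
    """Sketch of a cross-version cache adapter (illustrative only)."""

    def __init__(self, cache):
        if hasattr(cache, "layers"):
            # Newer transformers: one layer object per decoder layer,
            # assumed here to carry .keys/.values tensors.
            self.key_cache = [layer.keys for layer in cache.layers]
            self.value_cache = [layer.values for layer in cache.layers]
        else:
            # Older transformers: parallel lists of tensors.
            self.key_cache = list(cache.key_cache)
            self.value_cache = list(cache.value_cache)
```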
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

@@ -27,8 +27,8 @@ def _catch_produce_guards_and_solve_constraints(
     dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
     equalities_inputs: "EqualityConstraint",  # noqa: F821
     original_signature: inspect.Signature,
-    _is_torch_jit_trace: bool = False,
     verbose: int = 0,
+    **kwargs,
 ):
     try:
         return previous_function(
@@ -37,7 +37,7 @@ def _catch_produce_guards_and_solve_constraints(
             dynamic_shapes=dynamic_shapes,
             equalities_inputs=equalities_inputs,
             original_signature=original_signature,
-
+            **kwargs,
         )
     except Exception as e:
         if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
@@ -51,7 +51,7 @@ def _catch_produce_guards_and_solve_constraints(
             f"dynamic_shapes={dynamic_shapes}\n"
             f"equalities_inputs={equalities_inputs}\n"
             f"original_signature={original_signature}\n"
-            f"
+            f"kwargs={kwargs}\n"
             f"exc={e}\ngm={gm}"
         )
         torch._dynamo.reset()
@@ -174,7 +174,7 @@ class patched_ShapeEnv:
         self.counter["ignored_backward_guard"] += 1
         raise AssertionError(
             f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
-            f"this could result in accuracy problems
+            f"this could result in accuracy problems"
         )
 
     def _set_replacement(
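Replacing the explicit `_is_torch_jit_trace` parameter with `**kwargs` decouples the wrapper from the exact signature of torch's internal `produce_guards_and_solve_constraints`, which presumably changed across torch releases; whatever keywords the caller passes are now forwarded verbatim and included in the diagnostic message. A hedged sketch of that forwarding pattern (simplified; the real wrapper reports far more context):

```python
import functools

def make_tolerant_wrapper(previous_function):
    # Accept and forward arbitrary keyword arguments so the wrapper keeps
    # working when the wrapped torch internal gains or loses parameters
    # across versions.
    @functools.wraps(previous_function)
    def wrapper(*args, **kwargs):
        try:
            return previous_function(*args, **kwargs)
        except Exception as exc:
            # Mirror the patch's strategy: report and continue instead of
            # failing the whole export (illustrative behavior).
            print(f"produce_guards_and_solve_constraints failed: {exc}")
            return None
    return wrapper
```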