onnx-diagnostic 0.7.1-py3-none-any.whl → 0.7.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +22 -5
- onnx_diagnostic/ext_test_case.py +31 -0
- onnx_diagnostic/helpers/cache_helper.py +23 -12
- onnx_diagnostic/helpers/config_helper.py +16 -1
- onnx_diagnostic/helpers/log_helper.py +308 -83
- onnx_diagnostic/helpers/rt_helper.py +11 -1
- onnx_diagnostic/helpers/torch_helper.py +7 -3
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/text_generation.py +17 -8
- onnx_diagnostic/tasks/text_to_image.py +91 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
- onnx_diagnostic/torch_models/validate.py +36 -12
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -2,6 +2,7 @@ import inspect
 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -20,18 +21,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     ]
     if bh_indices:
         dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    # reshape
     dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
     dimensions = tuple(reversed(dimensions))
     indices = tuple(shape.index(-1) for shape in dimensions)
 
+    # unsqueeze
+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
+
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(
-            dimensions
-        )
+        assert len(args) == len(dimensions) == len(udimensions), (
+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+            f"and udimensions={udimensions}."
+        )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
+
         new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        # new_args = [
+        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+        #     for a, dims in zip(args, udimensions)
+        # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        # if is_torchdynamo_exporting():
+        #     for a in args:
+        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
+        #         torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)
 
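The patched `_vmap_for_bhqkv` builds the mask grid with plain `reshape`/`expand` broadcasting instead of `torch.vmap`. A minimal sketch of that trick with made-up sizes (the `dims` tuples mirror what the patch computes; nothing here comes from the package itself):

```python
import torch

def causal(batch, head, q_idx, kv_idx):
    # element-wise mask: a query may attend to keys at or before its position
    return kv_idx <= q_idx

# each 1-D index tensor gets its single dimension routed to a distinct axis ...
dims = [(-1, 1, 1, 1), (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]
args = [torch.arange(2), torch.arange(4), torch.arange(3), torch.arange(5)]
new_args = [a.reshape(shape) for a, shape in zip(args, dims)]
# ... then everything is expanded to the full (batch, head, q, kv) grid
max_shape = (2, 4, 3, 5)
expanded = [a.expand(max_shape) for a in new_args]
mask = causal(*expanded)
print(mask.shape)  # torch.Size([2, 4, 3, 5])
```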
@@ -190,8 +214,8 @@ class patched_DynamicCache:
         if len(self.key_cache) <= layer_idx:
             # There may be skipped layers, fill them with empty lists
             for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append(torch.tensor([]))
-                self.value_cache.append(torch.tensor([]))
+                self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
         elif not self.key_cache[
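The dtype change above matters because `torch.tensor([])` defaults to float32, so the placeholder entries inserted for skipped layers would not match a half-precision cache. A quick check:

```python
import torch

key_states = torch.zeros(1, 2, 4, 8, dtype=torch.float16)
placeholder = torch.tensor([])                    # float32 by default
fixed = torch.tensor([], dtype=key_states.dtype)  # follows the incoming key states
print(placeholder.dtype, fixed.dtype)  # torch.float32 torch.float16
```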
@@ -207,7 +231,6 @@ class patched_DynamicCache:
             self.value_cache[layer_idx] = torch.cat(
                 [self.value_cache[layer_idx], value_states], dim=-2
             )
-
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
     def crop(self, max_length: int):
@@ -791,10 +814,7 @@ def patched_dynamic_rope_update(rope_forward):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -820,6 +840,65 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
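The version gates rely on `packaging.version`, which compares releases numerically rather than lexicographically, so each patch class is only defined when the installed transformers actually ships the target class. A quick check:

```python
import packaging.version as pv

print(pv.Version("4.52.1") >= pv.Version("4.52"))  # True
print(pv.Version("4.9") >= pv.Version("4.52"))     # False: 9 < 52, not a string compare
```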
onnx_diagnostic/torch_export_patches/serialization/__init__.py (new file)

@@ -0,0 +1,46 @@
+import re
+from typing import Any, Callable, List, Set, Tuple
+import torch
+
+
+def _lower_name_with_(name):
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def make_serialization_function_for_dataclass(
+    cls: type, supported_classes: Set[type]
+) -> Tuple[Callable, Callable, Callable]:
+    """
+    Automatically creates serialization function for a class decorated with
+    ``dataclasses.dataclass``.
+    """
+
+    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:  # type: ignore[valid-type]
+        """Serializes a ``%s`` with python objects."""
+        return list(obj.values()), list(obj.keys())
+
+    def flatten_with_keys_cls(
+        obj: cls,  # type: ignore[valid-type]
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects with keys."""
+        values, context = list(obj.values()), list(obj.keys())
+        return [
+            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+        ], context
+
+    def unflatten_cls(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> cls:  # type: ignore[valid-type]
+        """Restores an instance of ``%s`` from python objects."""
+        return cls(**dict(zip(context, values)))
+
+    name = _lower_name_with_(cls.__name__)
+    flatten_cls.__name__ = f"flatten_{name}"
+    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
+    unflatten_cls.__name__ = f"unflatten_{name}"
+    flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
+    flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
+    unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
+    supported_classes.add(cls)
+    return flatten_cls, flatten_with_keys_cls, unflatten_cls
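The three generated functions are shaped for torch's pytree registry, which is what `torch.export` uses to flatten nested model outputs. A hypothetical usage sketch (the registration call is not part of this hunk, and the `flatten_with_keys_fn` keyword assumes a recent torch):

```python
import dataclasses
import torch
from onnx_diagnostic.torch_export_patches.serialization import (
    make_serialization_function_for_dataclass,
)

@dataclasses.dataclass
class DummyOutput:
    # stand-in for a transformers ModelOutput, which exposes keys()/values()
    logits: torch.Tensor

    def keys(self):
        return ["logits"]

    def values(self):
        return [self.logits]

supported: set = set()
flat, flat_keys, unflat = make_serialization_function_for_dataclass(DummyOutput, supported)
torch.utils._pytree.register_pytree_node(
    DummyOutput, flat, unflat, flatten_with_keys_fn=flat_keys
)
leaves, spec = torch.utils._pytree.tree_flatten(DummyOutput(torch.ones(2)))
print(leaves)  # [tensor([1., 1.])] -- the dataclass is now transparent to pytree
```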
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py (new file)

@@ -0,0 +1,34 @@
+from typing import Dict, Optional, Set
+
+try:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+except ImportError as e:
+    try:
+        import diffusers
+    except ImportError:
+        diffusers = None
+        UNet2DConditionOutput = None
+    if diffusers:
+        raise e
+
+from . import make_serialization_function_for_dataclass
+
+
+def _make_wrong_registrations() -> Dict[type, Optional[str]]:
+    res: Dict[type, Optional[str]] = {}
+    for c in [UNet2DConditionOutput]:
+        if c is not None:
+            res[c] = None
+    return res
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = _make_wrong_registrations()
+
+
+if UNet2DConditionOutput is not None:
+    (
+        flatten_u_net2_d_condition_output,
+        flatten_with_keys_u_net2_d_condition_output,
+        unflatten_u_net2_d_condition_output,
+    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES)
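The unusual generated names come from `_lower_name_with_`, which snake-cases the class name with two regex passes:

```python
from onnx_diagnostic.torch_export_patches.serialization import _lower_name_with_

print(_lower_name_with_("UNet2DConditionOutput"))  # u_net2_d_condition_output
```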
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py (new file)

@@ -0,0 +1,259 @@
+from typing import Any, List, Set, Tuple
+import torch
+import transformers
+from transformers.cache_utils import (
+    DynamicCache,
+    MambaCache,
+    EncoderDecoderCache,
+    SlidingWindowCache,
+    StaticCache,
+)
+from transformers.modeling_outputs import BaseModelOutput
+from ...helpers.cache_helper import make_static_cache
+from . import make_serialization_function_for_dataclass
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = {
+    DynamicCache: "4.50",
+    BaseModelOutput: None,
+}
+
+
+############
+# MambaCache
+############
+
+
+def flatten_mamba_cache(
+    mamba_cache: MambaCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    flat = [
+        ("conv_states", mamba_cache.conv_states),
+        ("ssm_states", mamba_cache.ssm_states),
+    ]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def unflatten_mamba_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
+    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
+    conv_states, ssm_states = values
+
+    class _config:
+        def __init__(self):
+            if isinstance(conv_states, list):
+                self.intermediate_size = conv_states[0].shape[1]
+                self.state_size = ssm_states[0].shape[2]
+                self.conv_kernel = conv_states[0].shape[2]
+                self.num_hidden_layers = len(conv_states)
+            else:
+                self.intermediate_size = conv_states.shape[2]
+                self.state_size = ssm_states.shape[3]
+                self.conv_kernel = conv_states.shape[3]
+                self.num_hidden_layers = conv_states.shape[0]
+
+    cache = MambaCache(
+        _config(),
+        max_batch_size=1,
+        dtype=values[-1][0].dtype,
+        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
+    )
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    values, context = flatten_mamba_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+##############
+# DynamicCache
+##############
+
+
+def flatten_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
+        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
+    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
+        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
+    values, context = flatten_dynamic_cache(dynamic_cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_dynamic_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
+        assert output_type is None, f"output_type={output_type} not supported"
+        return transformers.cache_utils._unflatten_dynamic_cache(values, context)
+
+    cache = transformers.cache_utils.DynamicCache()
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
+        f"Serialization doet not work when "
+        f"cache.max_cache_len={cache.max_cache_len} != "
+        f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
+    )
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(
+        list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
+    )
+
+
+####################
+# SlidingWindowCache
+####################
+
+
+def flatten_sliding_window_cache(
+    cache: SlidingWindowCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+    with python objects.
+    """
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_sliding_window_cache(
+    cache: SlidingWindowCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+    with python objects.
+    """
+    values, context = flatten_sliding_window_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_sliding_window_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> SlidingWindowCache:
+    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
+    key_cache, value_cache = values
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_cache[0].shape[-1]
+            self.num_attention_heads = key_cache[0].shape[1]
+            self.num_hidden_layers = len(key_cache)
+            self.sliding_window = key_cache[0].shape[2]
+
+    cache = SlidingWindowCache(
+        _config(),
+        max_batch_size=key_cache[0].shape[0],
+        max_cache_len=key_cache[0].shape[2],  # sligding window
+        device=key_cache[0].device,
+        dtype=key_cache[0].dtype,
+    )
+
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#####################
+# EncoderDecoderCache
+#####################
+
+
+def flatten_encoder_decoder_cache(
+    ec_cache: EncoderDecoderCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def unflatten_encoder_decoder_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
+    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    return EncoderDecoderCache(**dictionary)
+
+
+#############
+# dataclasses
+#############
+
+
+(
+    flatten_base_model_output,
+    flatten_with_keys_base_model_output,
+    unflatten_base_model_output,
+) = make_serialization_function_for_dataclass(BaseModelOutput, SUPPORTED_DATACLASSES)
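A roundtrip sketch for the `DynamicCache` helpers above (assuming a transformers version where `DynamicCache` still exposes `key_cache`/`value_cache` lists):

```python
import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.torch_export_patches.serialization.transformers_impl import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

cache = DynamicCache()
cache.update(torch.ones(1, 2, 3, 8), torch.zeros(1, 2, 3, 8), layer_idx=0)
values, context = flatten_dynamic_cache(cache)   # tensors out, keys as context
restored = unflatten_dynamic_cache(values, context)
print(context, restored.key_cache[0].shape)
```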
onnx_diagnostic/torch_models/hghub/hub_api.py

@@ -140,7 +140,10 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
 
 
 @functools.cache
 def task_from_arch(
-    arch: str,
+    arch: str,
+    default_value: Optional[str] = None,
+    model_id: Optional[str] = None,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     This function relies on stored information. That information needs to be refresh.
@@ -148,6 +151,7 @@ def task_from_arch(
     :param arch: architecture name
     :param default_value: default value in case the task cannot be determined
     :param model_id: unused unless the architecture does not help.
+    :param subfolder: subfolder
     :return: task
 
     .. runpython::
@@ -162,7 +166,7 @@ def task_from_arch(
     data = load_architecture_task()
     if arch not in data and model_id:
         # Let's try with the model id.
-        return task_from_id(model_id)
+        return task_from_id(model_id, subfolder=subfolder)
     if default_value is not None:
         return data.get(arch, default_value)
     assert arch in data, (
@@ -178,6 +182,7 @@ def task_from_id(
     default_value: Optional[str] = None,
     pretrained: bool = False,
     fall_back_to_pretrained: bool = True,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     Returns the task attached to a model id.
@@ -187,7 +192,7 @@ def task_from_id(
         if the task cannot be determined
     :param pretrained: uses the config
     :param fall_back_to_pretrained: falls back to pretrained config
-    :param
+    :param subfolder: subfolder
     :return: task
     """
     if not pretrained:
@@ -196,7 +201,7 @@ def task_from_id(
         except RuntimeError:
             if not fall_back_to_pretrained:
                 raise
-    config = get_pretrained_config(model_id)
+    config = get_pretrained_config(model_id, subfolder=subfolder)
     try:
         return config.pipeline_tag
     except AttributeError:
@@ -206,6 +211,12 @@ def task_from_id(
     data = load_architecture_task()
     if model_id in data:
         return data[model_id]
+    if type(config) is dict and "_class_name" in config:
+        return task_from_arch(config["_class_name"], default_value=default_value)
+    if not config.architectures or not config.architectures:
+        # Some hardcoded values until a better solution is found.
+        if model_id.startswith("google/bert_"):
+            return "fill-mask"
     assert config.architectures is not None and len(config.architectures) == 1, (
         f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
         f"architectures={config.architectures} in config={config}. "
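A local sketch of the new dict-config branch, no hub access needed: a diffusers-style config is a plain dict carrying `_class_name`, which is now routed through `task_from_arch` instead of `config.architectures`:

```python
# hypothetical dict config, shaped like the cached diffusers entry later in this diff
config = {"_class_name": "UNet2DConditionModel", "_diffusers_version": "0.8.0"}
if type(config) is dict and "_class_name" in config:
    arch = config["_class_name"]
print(arch)  # "UNet2DConditionModel", then looked up in load_architecture_task()
```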
onnx_diagnostic/torch_models/hghub/hub_data.py

@@ -22,6 +22,7 @@ __data_arch__ = textwrap.dedent(
     BlenderbotModel,feature-extraction
     BloomModel,feature-extraction
     CLIPModel,zero-shot-image-classification
+    CLIPTextModel,feature-extraction
     CLIPVisionModel,feature-extraction
     CamembertModel,feature-extraction
     CodeGenModel,feature-extraction
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -4302,3 +4302,31 @@ def _ccached_microsoft_phi_35_mini_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_diffusers_tiny_torch_full_checker_unet():
+    "diffusers/tiny-torch-full-checker/unet"
+    return {
+        "_class_name": "UNet2DConditionModel",
+        "_diffusers_version": "0.8.0",
+        "_name_or_path": "https://huggingface.co/diffusers/tiny-torch-full-checker/blob/main/unet/config.json",
+        "act_fn": "silu",
+        "attention_head_dim": 8,
+        "block_out_channels": [32, 64],
+        "center_input_sample": False,
+        "cross_attention_dim": 32,
+        "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D"],
+        "downsample_padding": 1,
+        "dual_cross_attention": False,
+        "flip_sin_to_cos": True,
+        "freq_shift": 0,
+        "in_channels": 4,
+        "layers_per_block": 2,
+        "mid_block_scale_factor": 1,
+        "norm_eps": 1e-05,
+        "norm_num_groups": 32,
+        "out_channels": 4,
+        "sample_size": 32,
+        "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"],
+        "use_linear_projection": False,
+    }
onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -106,7 +106,7 @@ def get_untrained_model_with_inputs(
         print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
     if task is None:
-        task = task_from_arch(archs[0], model_id=model_id)
+        task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
     if verbose:
         print(f"[get_untrained_model_with_inputs] task={task!r}")
 
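How the threaded-through `subfolder` is meant to be called; a sketch that may download a config from the Hugging Face hub, using the tiny model id whose config is cached later in this diff:

```python
from onnx_diagnostic.torch_models.hghub.hub_api import task_from_arch

task = task_from_arch(
    "UNet2DConditionModel",
    model_id="diffusers/tiny-torch-full-checker",
    subfolder="unet",
)
print(task)
```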
@@ -145,12 +145,19 @@ def get_untrained_model_with_inputs(
             f"{config._attn_implementation!r}"  # type: ignore[union-attr]
         )
 
+    if type(config) is dict and "_diffusers_version" in config:
+        import diffusers
+
+        package_source = diffusers
+    else:
+        package_source = transformers
+
     if use_pretrained:
         model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
     else:
         if archs is not None:
             try:
-                cls = getattr(transformers, archs[0])
+                cls_model = getattr(package_source, archs[0])
             except AttributeError as e:
                 # The code of the models is not in transformers but in the
                 # repository of the model. We need to download it.
@@ -174,10 +181,12 @@ def get_untrained_model_with_inputs(
                         f"[get_untrained_model_with_inputs] from folder "
                         f"{os.path.split(pyfiles[0])[0]!r}"
                     )
-                cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
-                    cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
+                cls_model = (
+                    transformers.dynamic_module_utils.get_class_from_dynamic_module(
+                        cls_name,
+                        pretrained_model_name_or_path=os.path.split(pyfiles[0])[0],
+                    )
                 )
-            model = cls(config)
         else:
             raise AttributeError(
                 f"Unable to find class 'tranformers.{archs[0]}'. "
@@ -191,6 +200,16 @@ def get_untrained_model_with_inputs(
                 f"and use_pretrained=True."
             )
 
+    try:
+        if type(config) is dict:
+            model = cls_model(**config)
+        else:
+            model = cls_model(config)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Unable to instantiate class {cls_model.__name__} with\n{config}"
+        ) from e
+
     # input kwargs
     kwargs, fct = random_input_kwargs(config, task)
     if verbose: