onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +22 -5
  3. onnx_diagnostic/ext_test_case.py +31 -0
  4. onnx_diagnostic/helpers/cache_helper.py +23 -12
  5. onnx_diagnostic/helpers/config_helper.py +16 -1
  6. onnx_diagnostic/helpers/log_helper.py +308 -83
  7. onnx_diagnostic/helpers/rt_helper.py +11 -1
  8. onnx_diagnostic/helpers/torch_helper.py +7 -3
  9. onnx_diagnostic/tasks/__init__.py +2 -0
  10. onnx_diagnostic/tasks/text_generation.py +17 -8
  11. onnx_diagnostic/tasks/text_to_image.py +91 -0
  12. onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
  15. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
  16. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  17. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  18. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  19. onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
  20. onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
  23. onnx_diagnostic/torch_models/validate.py +36 -12
  24. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
  25. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
  26. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -2,6 +2,7 @@ import inspect
 from dataclasses import dataclass
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
+import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
@@ -20,18 +21,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     ]
     if bh_indices:
         dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    # reshape
     dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
     dimensions = tuple(reversed(dimensions))
     indices = tuple(shape.index(-1) for shape in dimensions)
 
+    # unsqueeze
+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
+
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(
-            dimensions
-        ), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
+        assert len(args) == len(dimensions) == len(udimensions), (
+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+            f"and udimensions={udimensions}."
+        )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
+
         new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        # new_args = [
+        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+        #     for a, dims in zip(args, udimensions)
+        # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        # if is_torchdynamo_exporting():
+        #     for a in args:
+        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
+        #         torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)
 
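Note on the hunk above: `patched__vmap_for_bhqkv` avoids `torch.vmap` by reshaping each 1-D index tensor and broadcasting. A minimal standalone sketch of that trick (illustration only, not part of the diff; `causal_mask` is a made-up mask function):

import torch

def causal_mask(batch, head, q_idx, kv_idx):  # hypothetical mask_function
    return kv_idx <= q_idx

q = torch.arange(4).reshape((-1, 1))   # dimensions entry (-1, 1)
kv = torch.arange(6).reshape((1, -1))  # dimensions entry (1, -1)
# expand both 1-D index tensors to a common (q, kv) grid, then apply the mask
mask = causal_mask(None, None, q.expand(4, 6), kv.expand(4, 6))
print(mask.shape)  # torch.Size([4, 6])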
@@ -190,8 +214,8 @@ class patched_DynamicCache:
         if len(self.key_cache) <= layer_idx:
             # There may be skipped layers, fill them with empty lists
             for _ in range(len(self.key_cache), layer_idx):
-                self.key_cache.append(torch.tensor([]))
-                self.value_cache.append(torch.tensor([]))
+                self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
         elif not self.key_cache[
@@ -207,7 +231,6 @@ class patched_DynamicCache:
             self.value_cache[layer_idx] = torch.cat(
                 [self.value_cache[layer_idx], value_states], dim=-2
             )
-
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
     def crop(self, max_length: int):
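The dtype change above matters because `torch.tensor([])` always builds a float32 filler, which mismatches mixed-precision caches. A small check (illustration only, not part of the diff):

import torch

key_states = torch.zeros(1, 2, 3, 4, dtype=torch.float16)
print(torch.tensor([]).dtype)                          # torch.float32, wrong filler
print(torch.tensor([], dtype=key_states.dtype).dtype)  # torch.float16, matches the cache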
@@ -791,10 +814,7 @@ def patched_dynamic_rope_update(rope_forward):
     return wrapper
 
 
-class patched_Phi3RotaryEmbedding(torch.nn.Module):
-    _PATCHES_ = ["forward"]
-    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
-
+class common_RotaryEmbedding(torch.nn.Module):
     @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
@@ -820,6 +840,65 @@ class patched_Phi3RotaryEmbedding(torch.nn.Module):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
+class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding
+
+    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding
+
+
+class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding
+
+
+class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding
+
+
+class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding
+
+
+class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.51"):
+
+    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.52"):
+
+    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = (
+            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
+        )
+
+
+if pv.Version(transformers.__version__) >= pv.Version("4.53"):
+
+    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
+        _PATCHES_ = ["forward"]
+        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
+
+
 class patched_IdeficsEmbedding(torch.nn.Module):
     _PATCHES_ = ["forward"]
     _PATCHED_CLASS_ = transformers.models.idefics.modeling_idefics.IdeficsEmbedding
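The rotary-embedding patches above are created only when the installed transformers actually ships the target class; a hedged sketch of the gating pattern (illustration only, not part of the diff):

import packaging.version as pv
import transformers

# Gemma2/Gemma3 rotary embeddings only exist in recent transformers,
# so the corresponding patch classes are defined conditionally.
if pv.Version(transformers.__version__) >= pv.Version("4.52"):
    print("patched_Gemma2RotaryEmbedding / patched_Gemma3RotaryEmbedding defined")
else:
    print("older transformers: these patches are skipped")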
onnx_diagnostic/torch_export_patches/serialization/__init__.py
@@ -0,0 +1,46 @@
+import re
+from typing import Any, Callable, List, Set, Tuple
+import torch
+
+
+def _lower_name_with_(name):
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def make_serialization_function_for_dataclass(
+    cls: type, supported_classes: Set[type]
+) -> Tuple[Callable, Callable, Callable]:
+    """
+    Automatically creates serialization functions for a class decorated with
+    ``dataclasses.dataclass``.
+    """
+
+    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:  # type: ignore[valid-type]
+        """Serializes a ``%s`` with python objects."""
+        return list(obj.values()), list(obj.keys())
+
+    def flatten_with_keys_cls(
+        obj: cls,  # type: ignore[valid-type]
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects with keys."""
+        values, context = list(obj.values()), list(obj.keys())
+        return [
+            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+        ], context
+
+    def unflatten_cls(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> cls:  # type: ignore[valid-type]
+        """Restores an instance of ``%s`` from python objects."""
+        return cls(**dict(zip(context, values)))
+
+    name = _lower_name_with_(cls.__name__)
+    flatten_cls.__name__ = f"flatten_{name}"
+    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
+    unflatten_cls.__name__ = f"unflatten_{name}"
+    flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
+    flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
+    unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
+    supported_classes.add(cls)
+    return flatten_cls, flatten_with_keys_cls, unflatten_cls
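Usage sketch for the new helper (assumptions: onnx_diagnostic 0.7.3 installed; the dataclass must expose dict-like `.keys()`/`.values()`, as transformers' `ModelOutput` subclasses do; `TinyOutput` is a made-up class):

import dataclasses
import torch
from onnx_diagnostic.torch_export_patches.serialization import (
    make_serialization_function_for_dataclass,
)

@dataclasses.dataclass
class TinyOutput:
    hidden: torch.Tensor

    def keys(self):
        return ["hidden"]

    def values(self):
        return [self.hidden]

supported: set = set()
flat, flat_keys, unflat = make_serialization_function_for_dataclass(TinyOutput, supported)
values, context = flat(TinyOutput(hidden=torch.ones(2)))
print(flat.__name__, context)   # flatten_tiny_output ['hidden']
print(unflat(values, context))  # TinyOutput(hidden=tensor([1., 1.]))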
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py
@@ -0,0 +1,34 @@
+from typing import Dict, Optional, Set
+
+try:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+except ImportError as e:
+    try:
+        import diffusers
+    except ImportError:
+        diffusers = None
+    UNet2DConditionOutput = None
+    if diffusers:
+        raise e
+
+from . import make_serialization_function_for_dataclass
+
+
+def _make_wrong_registrations() -> Dict[type, Optional[str]]:
+    res: Dict[type, Optional[str]] = {}
+    for c in [UNet2DConditionOutput]:
+        if c is not None:
+            res[c] = None
+    return res
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = _make_wrong_registrations()
+
+
+if UNet2DConditionOutput is not None:
+    (
+        flatten_u_net2_d_condition_output,
+        flatten_with_keys_u_net2_d_condition_output,
+        unflatten_u_net2_d_condition_output,
+    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES)
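The unusual `flatten_u_net2_d_condition_output` names come from `_lower_name_with_` in the new `serialization/__init__.py`; a quick check (assumes the package is installed):

from onnx_diagnostic.torch_export_patches.serialization import _lower_name_with_

print(_lower_name_with_("UNet2DConditionOutput"))  # u_net2_d_condition_output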
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -0,0 +1,259 @@
+from typing import Any, List, Set, Tuple
+import torch
+import transformers
+from transformers.cache_utils import (
+    DynamicCache,
+    MambaCache,
+    EncoderDecoderCache,
+    SlidingWindowCache,
+    StaticCache,
+)
+from transformers.modeling_outputs import BaseModelOutput
+from ...helpers.cache_helper import make_static_cache
+from . import make_serialization_function_for_dataclass
+
+
+SUPPORTED_DATACLASSES: Set[type] = set()
+WRONG_REGISTRATIONS = {
+    DynamicCache: "4.50",
+    BaseModelOutput: None,
+}
+
+
+############
+# MambaCache
+############
+
+
+def flatten_mamba_cache(
+    mamba_cache: MambaCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    flat = [
+        ("conv_states", mamba_cache.conv_states),
+        ("ssm_states", mamba_cache.ssm_states),
+    ]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def unflatten_mamba_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
+    """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
+    conv_states, ssm_states = values
+
+    class _config:
+        def __init__(self):
+            if isinstance(conv_states, list):
+                self.intermediate_size = conv_states[0].shape[1]
+                self.state_size = ssm_states[0].shape[2]
+                self.conv_kernel = conv_states[0].shape[2]
+                self.num_hidden_layers = len(conv_states)
+            else:
+                self.intermediate_size = conv_states.shape[2]
+                self.state_size = ssm_states.shape[3]
+                self.conv_kernel = conv_states.shape[3]
+                self.num_hidden_layers = conv_states.shape[0]
+
+    cache = MambaCache(
+        _config(),
+        max_batch_size=1,
+        dtype=values[-1][0].dtype,
+        device="cpu" if values[-1][0].get_device() < 0 else "cuda",
+    )
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
+    values, context = flatten_mamba_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+##############
+# DynamicCache
+##############
+
+
+def flatten_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
+        return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
+    flat = [("key_cache", dynamic_cache.key_cache), ("value_cache", dynamic_cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_dynamic_cache(
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+    if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
+        return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
+    values, context = flatten_dynamic_cache(dynamic_cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_dynamic_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> DynamicCache:
+    """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
+    if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
+        assert output_type is None, f"output_type={output_type} not supported"
+        return transformers.cache_utils._unflatten_dynamic_cache(values, context)
+
+    cache = transformers.cache_utils.DynamicCache()
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
+        f"Serialization does not work when "
+        f"cache.max_cache_len={cache.max_cache_len} != "
+        f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
+    )
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(
+        list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
+    )
+
+
+####################
+# SlidingWindowCache
+####################
+
+
+def flatten_sliding_window_cache(
+    cache: SlidingWindowCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+    with python objects.
+    """
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_sliding_window_cache(
+    cache: SlidingWindowCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.SlidingWindowCache`
+    with python objects.
+    """
+    values, context = flatten_sliding_window_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_sliding_window_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> SlidingWindowCache:
+    """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
+    key_cache, value_cache = values
+
+    class _config:
+        def __init__(self):
+            self.head_dim = key_cache[0].shape[-1]
+            self.num_attention_heads = key_cache[0].shape[1]
+            self.num_hidden_layers = len(key_cache)
+            self.sliding_window = key_cache[0].shape[2]
+
+    cache = SlidingWindowCache(
+        _config(),
+        max_batch_size=key_cache[0].shape[0],
+        max_cache_len=key_cache[0].shape[2],  # sliding window
+        device=key_cache[0].device,
+        dtype=key_cache[0].dtype,
+    )
+
+    values = dict(zip(context, values))
+    for k, v in values.items():
+        setattr(cache, k, v)
+    return cache
+
+
+#####################
+# EncoderDecoderCache
+#####################
+
+
+def flatten_encoder_decoder_cache(
+    ec_cache: EncoderDecoderCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten(dictionary)
+
+
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
+    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
+    torch.utils._pytree.Context,
+]:
+    """
+    Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
+    with python objects.
+    """
+    dictionary = {
+        "self_attention_cache": ec_cache.self_attention_cache,
+        "cross_attention_cache": ec_cache.cross_attention_cache,
+    }
+    return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+
+def unflatten_encoder_decoder_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
+    """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    return EncoderDecoderCache(**dictionary)
+
+
+#############
+# dataclasses
+#############
+
+
+(
+    flatten_base_model_output,
+    flatten_with_keys_base_model_output,
+    unflatten_base_model_output,
+) = make_serialization_function_for_dataclass(BaseModelOutput, SUPPORTED_DATACLASSES)
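Round-trip sketch for the cache helpers above (assumptions: onnx_diagnostic 0.7.3 and a transformers version where these helpers apply; the printed context is indicative):

import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.torch_export_patches.serialization.transformers_impl import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

cache = DynamicCache()
cache.update(torch.ones(1, 2, 3, 4), torch.zeros(1, 2, 3, 4), layer_idx=0)
values, context = flatten_dynamic_cache(cache)
restored = unflatten_dynamic_cache(values, context)
print(context)  # e.g. ['key_cache', 'value_cache']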
onnx_diagnostic/torch_models/hghub/hub_api.py
@@ -140,7 +140,10 @@ def _guess_task_from_config(config: Any) -> Optional[str]:
 
 @functools.cache
 def task_from_arch(
-    arch: str, default_value: Optional[str] = None, model_id: Optional[str] = None
+    arch: str,
+    default_value: Optional[str] = None,
+    model_id: Optional[str] = None,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     This function relies on stored information. That information needs to be refreshed.
@@ -148,6 +151,7 @@ def task_from_arch(
     :param arch: architecture name
     :param default_value: default value in case the task cannot be determined
     :param model_id: unused unless the architecture does not help.
+    :param subfolder: subfolder
     :return: task
 
     .. runpython::
@@ -162,7 +166,7 @@ def task_from_arch(
     data = load_architecture_task()
     if arch not in data and model_id:
         # Let's try with the model id.
-        return task_from_id(model_id)
+        return task_from_id(model_id, subfolder=subfolder)
     if default_value is not None:
         return data.get(arch, default_value)
     assert arch in data, (
@@ -178,6 +182,7 @@ def task_from_id(
     default_value: Optional[str] = None,
     pretrained: bool = False,
    fall_back_to_pretrained: bool = True,
+    subfolder: Optional[str] = None,
 ) -> str:
     """
     Returns the task attached to a model id.
@@ -187,7 +192,7 @@ def task_from_id(
         if the task cannot be determined
     :param pretrained: uses the config
     :param fall_back_to_pretrained: falls back to pretrained config
-    :param exc: raises an exception if True
+    :param subfolder: subfolder
     :return: task
     """
     if not pretrained:
@@ -196,7 +201,7 @@ def task_from_id(
         except RuntimeError:
             if not fall_back_to_pretrained:
                 raise
-    config = get_pretrained_config(model_id)
+    config = get_pretrained_config(model_id, subfolder=subfolder)
     try:
         return config.pipeline_tag
     except AttributeError:
@@ -206,6 +211,12 @@ def task_from_id(
     data = load_architecture_task()
     if model_id in data:
         return data[model_id]
+    if type(config) is dict and "_class_name" in config:
+        return task_from_arch(config["_class_name"], default_value=default_value)
+    if not config.architectures:
+        # Some hardcoded values until a better solution is found.
+        if model_id.startswith("google/bert_"):
+            return "fill-mask"
     assert config.architectures is not None and len(config.architectures) == 1, (
         f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
         f"architectures={config.architectures} in config={config}. "
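The new `subfolder` argument threads through `task_from_arch` → `task_from_id` → `get_pretrained_config`, which is what lets a pipeline component such as a diffusers `unet/` resolve its own config. A hedged call sketch (requires network access to the Hugging Face Hub; resolution depends on the stored architecture table):

from onnx_diagnostic.torch_models.hghub.hub_api import task_from_arch

# An architecture missing from the stored table falls back to the model id,
# now honoring the subfolder where the component's config.json lives.
task = task_from_arch(
    "UNet2DConditionModel",
    model_id="diffusers/tiny-torch-full-checker",
    subfolder="unet",
)
print(task)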
onnx_diagnostic/torch_models/hghub/hub_data.py
@@ -22,6 +22,7 @@ __data_arch__ = textwrap.dedent(
     BlenderbotModel,feature-extraction
     BloomModel,feature-extraction
     CLIPModel,zero-shot-image-classification
+    CLIPTextModel,feature-extraction
     CLIPVisionModel,feature-extraction
     CamembertModel,feature-extraction
     CodeGenModel,feature-extraction
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py
@@ -4302,3 +4302,31 @@ def _ccached_microsoft_phi_35_mini_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_diffusers_tiny_torch_full_checker_unet():
+    "diffusers/tiny-torch-full-checker/unet"
+    return {
+        "_class_name": "UNet2DConditionModel",
+        "_diffusers_version": "0.8.0",
+        "_name_or_path": "https://huggingface.co/diffusers/tiny-torch-full-checker/blob/main/unet/config.json",
+        "act_fn": "silu",
+        "attention_head_dim": 8,
+        "block_out_channels": [32, 64],
+        "center_input_sample": false,
+        "cross_attention_dim": 32,
+        "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D"],
+        "downsample_padding": 1,
+        "dual_cross_attention": false,
+        "flip_sin_to_cos": true,
+        "freq_shift": 0,
+        "in_channels": 4,
+        "layers_per_block": 2,
+        "mid_block_scale_factor": 1,
+        "norm_eps": 1e-05,
+        "norm_num_groups": 32,
+        "out_channels": 4,
+        "sample_size": 32,
+        "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"],
+        "use_linear_projection": false,
+    }
onnx_diagnostic/torch_models/hghub/model_inputs.py
@@ -106,7 +106,7 @@ def get_untrained_model_with_inputs(
         print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
         print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
     if task is None:
-        task = task_from_arch(archs[0], model_id=model_id)
+        task = task_from_arch(archs[0], model_id=model_id, subfolder=subfolder)
     if verbose:
         print(f"[get_untrained_model_with_inputs] task={task!r}")
 
@@ -145,12 +145,19 @@ def get_untrained_model_with_inputs(
             f"{config._attn_implementation!r}"  # type: ignore[union-attr]
         )
 
+    if type(config) is dict and "_diffusers_version" in config:
+        import diffusers
+
+        package_source = diffusers
+    else:
+        package_source = transformers
+
     if use_pretrained:
         model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
     else:
         if archs is not None:
             try:
-                model = getattr(transformers, archs[0])(config)
+                cls_model = getattr(package_source, archs[0])
             except AttributeError as e:
                 # The code of the models is not in transformers but in the
                 # repository of the model. We need to download it.
@@ -174,10 +181,12 @@ def get_untrained_model_with_inputs(
                         f"[get_untrained_model_with_inputs] from folder "
                         f"{os.path.split(pyfiles[0])[0]!r}"
                     )
-                cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
-                    cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
+                cls_model = (
+                    transformers.dynamic_module_utils.get_class_from_dynamic_module(
+                        cls_name,
+                        pretrained_model_name_or_path=os.path.split(pyfiles[0])[0],
+                    )
                 )
-                model = cls(config)
             else:
                 raise AttributeError(
                     f"Unable to find class 'transformers.{archs[0]}'. "
@@ -191,6 +200,16 @@ def get_untrained_model_with_inputs(
                     f"and use_pretrained=True."
                 )
 
+        try:
+            if type(config) is dict:
+                model = cls_model(**config)
+            else:
+                model = cls_model(config)
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unable to instantiate class {cls_model.__name__} with\n{config}"
+            ) from e
+
     # input kwargs
     kwargs, fct = random_input_kwargs(config, task)
     if verbose:
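The dispatch added above boils down to: a plain-dict config carrying `_diffusers_version` resolves its class from diffusers and is instantiated with `**config`, everything else stays on transformers. A minimal sketch of that branch (illustration only, not part of the diff; assumes diffusers is installed):

import transformers

def pick_source(config):
    # mirrors the new branch in get_untrained_model_with_inputs
    if type(config) is dict and "_diffusers_version" in config:
        import diffusers
        return diffusers
    return transformers

print(pick_source({"_diffusers_version": "0.8.0"}).__name__)  # diffusers
print(pick_source(transformers.PretrainedConfig()).__name__)  # transformers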