onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +281 -80
  3. onnx_diagnostic/doc.py +22 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +48 -20
  5. onnx_diagnostic/export/shape_helper.py +126 -0
  6. onnx_diagnostic/ext_test_case.py +1 -1
  7. onnx_diagnostic/helpers/cache_helper.py +78 -8
  8. onnx_diagnostic/helpers/config_helper.py +8 -4
  9. onnx_diagnostic/helpers/helper.py +30 -3
  10. onnx_diagnostic/helpers/log_helper.py +1744 -0
  11. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  12. onnx_diagnostic/helpers/model_builder_helper.py +54 -73
  13. onnx_diagnostic/helpers/torch_helper.py +18 -2
  14. onnx_diagnostic/reference/__init__.py +1 -0
  15. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  16. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  17. onnx_diagnostic/reference/torch_evaluator.py +21 -0
  18. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  19. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  20. onnx_diagnostic/tasks/fill_mask.py +3 -0
  21. onnx_diagnostic/tasks/image_classification.py +7 -1
  22. onnx_diagnostic/tasks/image_text_to_text.py +72 -18
  23. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  24. onnx_diagnostic/tasks/object_detection.py +3 -0
  25. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  26. onnx_diagnostic/tasks/summarization.py +3 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  28. onnx_diagnostic/tasks/text_classification.py +3 -0
  29. onnx_diagnostic/tasks/text_generation.py +90 -43
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  33. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  34. onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
  35. onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
  36. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
  38. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  39. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
  40. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
  41. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
  42. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
  43. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
  44. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/image_text_to_text.py

@@ -52,6 +52,9 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
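The same guard appears in every task file listed above with a +3 -0 change: tasks that cannot build alternative cache classes now reject the cls_cache option up front instead of silently ignoring it. A minimal sketch of the pattern (the parameter list is abbreviated and illustrative; only the kwargs check mirrors the released code):

    # Sketch: the guard added at the top of get_inputs in most task modules.
    def get_inputs(model, config, **kwargs):
        # Fail fast with an informative message when an unsupported
        # cache class is requested for this task.
        assert (
            "cls_cache" not in kwargs
        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
        ...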
@@ -93,10 +96,10 @@ def get_inputs(
                 for i in range(num_hidden_layers)
             ]
         ),
-        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
             torch.int64
         ),
-        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
+        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
             torch.int64
         ),
     )
@@ -129,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_key_value_heads", "num_attention_heads"),
-            "intermediate_size",
-            "hidden_size",
-            "vision_config",
-        )
+        if hasattr(config, "text_config"):
+            check_hasattr(
+                config.text_config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+            )
+            check_hasattr(config, "vision_config")
+            text_config = True
+        else:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "vision_config",
+            )
+            text_config = False
         check_hasattr(config.vision_config, "image_size", "num_channels")
     kwargs = dict(
         batch_size=2,
@@ -147,17 +164,54 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         head_dim=(
             16
             if config is None
-            else getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+            else getattr(
+                config,
+                "head_dim",
+                (config.text_config.hidden_size if text_config else config.hidden_size)
+                // (
+                    config.text_config.num_attention_heads
+                    if text_config
+                    else config.num_attention_heads
+                ),
+            )
+        ),
+        dummy_max_token_id=(
+            31999
+            if config is None
+            else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
+        ),
+        num_hidden_layers=(
+            4
+            if config is None
+            else (
+                config.text_config.num_hidden_layers
+                if text_config
+                else config.num_hidden_layers
+            )
         ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=4 if config is None else config.num_hidden_layers,
         num_key_value_heads=(
             8
             if config is None
-            else _pick(config, "num_key_value_heads", "num_attention_heads")
+            else (
+                _pick(config.text_config, "num_key_value_heads", "num_attention_heads")
+                if text_config
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            )
+        ),
+        intermediate_size=(
+            1024
+            if config is None
+            else (
+                config.text_config.intermediate_size
+                if text_config
+                else config.intermediate_size
+            )
+        ),
+        hidden_size=(
+            512
+            if config is None
+            else (config.text_config.hidden_size if text_config else config.hidden_size)
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
-        hidden_size=512 if config is None else config.hidden_size,
         width=224 if config is None else config.vision_config.image_size,
         height=224 if config is None else config.vision_config.image_size,
         num_channels=3 if config is None else config.vision_config.num_channels,
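Multimodal checkpoints nest the language-model hyperparameters under config.text_config, which is why every default above now branches on text_config. The repeated conditional is equivalent to a small lookup helper; a hedged sketch (the helper name is illustrative, the package inlines the conditionals instead):

    # Illustrative only: equivalent of the inlined
    # "... if text_config else ..." conditionals in random_input_kwargs.
    def _text_attr(config, name, text_config: bool):
        """Reads `name` from config.text_config when the config is multimodal."""
        sub = config.text_config if text_config else config
        return getattr(sub, name)

    # head_dim fallback as computed above:
    # _text_attr(cfg, "hidden_size", tc) // _text_attr(cfg, "num_attention_heads", tc)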
onnx_diagnostic/tasks/mixture_of_expert.py

@@ -61,6 +61,9 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert not add_second_input, "add_second_input=True not yet implemented"
     raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
onnx_diagnostic/tasks/object_detection.py

@@ -41,6 +41,9 @@ def get_inputs(
     :param input_height: input height
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"
onnx_diagnostic/tasks/sentence_similarity.py

@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"
     shapes = {
onnx_diagnostic/tasks/summarization.py

@@ -62,6 +62,9 @@ def get_inputs(
         decoder_input_ids:T7s1x1,
         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
onnx_diagnostic/tasks/text2text_generation.py

@@ -64,6 +64,9 @@ def get_inputs(
         decoder_input_ids:T7s1x1,
         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
onnx_diagnostic/tasks/text_classification.py

@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
     shapes = {
onnx_diagnostic/tasks/text_generation.py

@@ -5,6 +5,7 @@ from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
     make_sliding_window_cache,
+    make_static_cache,
 )
 from ..helpers.config_helper import update_config, check_hasattr, _pick
@@ -151,52 +152,98 @@ def get_inputs(
         assert config, "head_dim is None, the value cannot be set without a configuration"
         head_dim = config.hidden_size // config.num_attention_heads

-    shapes = {
-        "input_ids": {0: batch, 1: seq_length},
-        "attention_mask": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-        ],
+    cache_name = (
+        cls_cache
+        if cls_cache is None or isinstance(cls_cache, str)
+        else cls_cache.__name__
+    )
+    make_caches = {
+        "DynamicCache": make_dynamic_cache,
+        "SlidingWindowCache": make_sliding_window_cache,
+        "StaticCache": make_static_cache,
     }
-
-    make_cache = (
-        make_sliding_window_cache
-        if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
-        else make_dynamic_cache
+    assert cache_name is None or cache_name in make_caches, (
+        f"Unable to handle cls_cache={cache_name!r}, it should be in "
+        f"{sorted(make_caches)}"
     )
+    make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
+    is_static = cache_name == "StaticCache"

-    inputs = dict(
-        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_cache(
-            [
-                (
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, head_dim
-                    ),
-                    torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, head_dim
-                    ),
-                )
-                for i in range(num_hidden_layers)
-            ]
-        ),
-    )
+    if is_static:
+        # static
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {0: batch, 2: "seq"},
+            "cache_position": {0: "seq"},
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones(
+                (batch_size, num_key_value_heads, sequence_length2, head_dim)
+            ).to(torch.bool),
+            cache_position=torch.arange(sequence_length2).to(torch.int64),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+    else:
+        # dynamic
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "attention_mask": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "position_ids": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "past_key_values": [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+        }
+
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                torch.int64
+            ),
+            position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         res["inputs2"] = get_inputs(
onnx_diagnostic/tasks/zero_shot_image_classification.py

@@ -55,6 +55,9 @@ def get_inputs(
     # attention_mask:T7s2x7
     # pixel_values:T1s2x3x224x224
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -1,5 +1,8 @@
+import functools
+import importlib
 import contextlib
-from typing import Any, Callable, Dict, List, Optional
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,
@@ -7,6 +10,41 @@ from .onnx_export_serialization import (
 from .patches import patch_transformers as patch_transformers_list


+def get_function(name: str) -> Tuple[type, Callable]:
+    """Returns the module and the function based on its name."""
+    spl = name.split(".")
+    module_name = ".".join(spl[:-1])
+    fname = spl[-1]
+    mod = importlib.import_module(module_name)
+    return mod, getattr(mod, fname)
+
+
+@functools.lru_cache
+def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
+    """Returns the list of patches to make for a specific module."""
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+            else:
+                # a function
+                doc = v.__doc__.lstrip()
+                if doc.startswith("manual patch"):
+                    continue
+                reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+                fall = reg.findall(doc)
+                assert (
+                    len(fall) == 1
+                ), f"Unable to find patching information for {v} in \n{doc}"
+                fmod, f = get_function(fall[0])
+                to_patch.append({"module": fmod, "function": f, "patch": v})
+
+    name = mod.__name__
+    return name, to_patch
+
+
 def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
     """
     Applies all patches defined in classes prefixed by ``patched_``
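get_patches now discovers plain function patches in addition to patched_* classes: a replacement function advertises its target through a [patch:<module>.<function>] tag in its docstring, resolved by the regex above (functions whose docstring starts with "manual patch" are skipped). A sketch of a discoverable patch; the dotted target path is made up for illustration:

    # Illustrative: a function patch that get_patches() can resolve.
    # The [patch:...] tag names the original callable; the path is hypothetical.
    def patched_compute_mask(*args, **kwargs):
        """[patch:some_library.masking.compute_mask]"""
        # replacement behaviour goes here
        ...

    # get_patches() would record for it:
    # {"module": some_library.masking, "function": compute_mask,
    #  "patch": patched_compute_mask}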
@@ -23,16 +61,21 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]:
         to_patch = mod
         name = "list"
     else:
-        to_patch = []
-        for k in dir(mod):
-            if k.startswith("patched_"):
-                v = getattr(mod, k)
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
+        name, to_patch = get_patches(mod, verbose)

     res = {}
     for cls in to_patch:
+        if isinstance(cls, dict):
+            # a function
+            keep = {}
+            original = cls["module"]
+            f = cls["function"]
+            res[f] = f
+            if verbose:
+                print(f"[patch_module_or_classes] function: {original.__name__}.{f.__name__}")
+            setattr(original, f.__name__, cls["patch"])
+            continue
+
         original = cls._PATCHED_CLASS_
         methods = cls._PATCHES_
         if verbose:
@@ -57,26 +100,36 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0):
         to_patch = mod
         name = "list"
     else:
-        to_patch = []
-        for k in dir(mod):
-            if k.startswith("patched_"):
-                v = getattr(mod, k)
-                if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
-                    to_patch.append(v)
-        name = mod.__name__
-    set_patch = set(to_patch)
+        name, to_patch = get_patches(mod, verbose)
+
+    set_patch_cls = {i for i in to_patch if not isinstance(i, dict)}
+    dict_patch_fct = {i["function"]: i for i in to_patch if isinstance(i, dict)}

     for cls, methods in info.items():
-        assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
+        if cls in set_patch_cls:
+            if verbose:
+                print(
+                    f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}"
+                )
+            original = cls._PATCHED_CLASS_
+            for n, v in methods.items():
+                if v is None:
+                    # The method did not exist. We remove it.
+                    delattr(original, n)
+                else:
+                    setattr(original, n, v)
+            continue
+        assert cls in dict_patch_fct, (
+            f"No patch registered for {cls} in {mod} "
+            f"(found {set_patch_cls} and {set(dict_patch_fct)})"
+        )
+        patch = dict_patch_fct[cls]
         if verbose:
-            print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}")
-        original = cls._PATCHED_CLASS_
-        for n, v in methods.items():
-            if v is None:
-                # The method did not exist. We remove it.
-                delattr(original, n)
-            else:
-                setattr(original, n, v)
+            print(
+                f"[unpatch_module_or_classes] function "
+                f"{patch['module'].__name__}.{cls.__name__}"
+            )
+        setattr(patch["module"], cls.__name__, patch["function"])


 @contextlib.contextmanager
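patch_module_or_classes returns enough state (info) for unpatch_module_or_classes to restore both kinds of patch: class methods are reinstated one by one, function patches by re-binding the original callable on its module. A usage sketch, assuming the bundled patch_transformers module:

    # Sketch: apply the patches around an export, then restore the originals.
    from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
        patch_module_or_classes,
        unpatch_module_or_classes,
    )
    from onnx_diagnostic.torch_export_patches.patches import patch_transformers

    info = patch_module_or_classes(patch_transformers, verbose=1)
    try:
        ...  # e.g. torch.export.export(model, args) with the patches active
    finally:
        unpatch_module_or_classes(patch_transformers, info, verbose=1)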
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

@@ -9,9 +9,11 @@ from transformers.cache_utils import (
     MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
+    StaticCache,
 )
 from transformers.modeling_outputs import BaseModelOutput
 from ..helpers import string_type
+from ..helpers.cache_helper import make_static_cache


 PATCH_OF_PATCHES: Set[Any] = set()
@@ -175,6 +177,13 @@ def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]
         flatten_with_keys_sliding_window_cache,
         verbose=verbose,
     ),
+    StaticCache=register_class_serialization(
+        StaticCache,
+        flatten_static_cache,
+        unflatten_static_cache,
+        flatten_with_keys_static_cache,
+        verbose=verbose,
+    ),
 )

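The package's register_class_serialization helper wires these three callables (defined in the following hunk) into torch.utils._pytree. A hedged sketch of what that registration presumably amounts to with the public torch API, assuming a torch recent enough to accept flatten_with_keys_fn:

    # Sketch: direct pytree registration equivalent (assumption, not the
    # package's actual register_class_serialization implementation).
    import torch
    from transformers.cache_utils import StaticCache

    torch.utils._pytree.register_pytree_node(
        StaticCache,
        flatten_static_cache,    # cache -> (children, context)
        unflatten_static_cache,  # (children, context) -> cache
        serialized_type_name=f"{StaticCache.__module__}.{StaticCache.__name__}",
        flatten_with_keys_fn=flatten_with_keys_static_cache,
    )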
@@ -309,6 +318,34 @@ def unflatten_dynamic_cache(
     return cache


+#############
+# StaticCache
+#############
+
+
+def flatten_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
+    return [f[1] for f in flat], [f[0] for f in flat]
+
+
+def flatten_with_keys_static_cache(
+    cache: StaticCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    values, context = flatten_static_cache(cache)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_static_cache(
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> StaticCache:
+    """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
+    return make_static_cache(list(zip(values[0], values[1])))
+
+
 ####################
 # SlidingWindowCache
 ####################
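Once the serialization is registered (for example via serialization_functions inside the package's export-patches context), pytree can take a StaticCache apart and rebuild it, which is what torch.export needs to trace through cache arguments. A minimal round-trip sketch with illustrative shapes:

    # Sketch: pytree round trip for a registered StaticCache.
    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    cache = make_static_cache(
        [(torch.randn(2, 8, 16, 64), torch.randn(2, 8, 16, 64)) for _ in range(4)]
    )
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    # flat holds the key/value tensors; spec remembers the StaticCache structure.
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
    assert type(restored) is type(cache)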