onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +66 -8
  3. onnx_diagnostic/ext_test_case.py +2 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +250 -15
  6. onnx_diagnostic/helpers/helper.py +146 -10
  7. onnx_diagnostic/helpers/log_helper.py +404 -315
  8. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  9. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  10. onnx_diagnostic/helpers/torch_helper.py +33 -11
  11. onnx_diagnostic/tasks/__init__.py +2 -0
  12. onnx_diagnostic/tasks/feature_extraction.py +86 -5
  13. onnx_diagnostic/tasks/image_text_to_text.py +260 -56
  14. onnx_diagnostic/tasks/mask_generation.py +139 -0
  15. onnx_diagnostic/tasks/text2text_generation.py +2 -2
  16. onnx_diagnostic/tasks/text_generation.py +6 -2
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
  18. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  19. onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
  21. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
  24. onnx_diagnostic/torch_models/validate.py +26 -3
  25. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/image_text_to_text.py
@@ -1,6 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache
+from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
 from ..helpers.config_helper import update_config, check_hasattr, _pick

 __TASK__ = "image-text-to-text"
@@ -11,99 +11,284 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, 2)
-    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
-        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
+    if hasattr(config, "mm_tokens_per_image"):
+        config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
+    if hasattr(config, "vision_config"):
+        if hasattr(config.vision_config, "num_hidden_layers"):
+            config.vision_config.num_hidden_layers = min(
+                config.vision_config.num_hidden_layers, 2
+            )
+        if hasattr(config.vision_config, "image_size"):
+            config.vision_config.image_size = min(config.vision_config.image_size, 96)
+        if hasattr(config.vision_config, "intermediate_size"):
+            config.vision_config.intermediate_size = min(
+                config.vision_config.intermediate_size, 1076
+            )
+        if hasattr(config.vision_config, "patch_size"):
+            config.vision_config.patch_size = min(config.vision_config.patch_size, 2)
+        if hasattr(config.vision_config, "hidden_size"):
+            config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
+    if hasattr(config, "text_config"):
+        if hasattr(config.text_config, "intermediate_size"):
+            config.text_config.intermediate_size = min(
+                config.text_config.intermediate_size, 320
+            )
+        if hasattr(config.text_config, "hidden_size"):
+            config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
+        if hasattr(config.text_config, "num_hidden_layers"):
+            config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
+        if hasattr(config.text_config, "layer_types"):
+            config.text_config.layer_types = config.text_config.layer_types[
+                : config.text_config.num_hidden_layers
+            ]
+        if hasattr(config.text_config, "num_attention_heads"):
+            config.text_config.num_attention_heads = min(
+                config.text_config.num_attention_heads, 2
+            )
     update_config(config, kwargs)
     return kwargs


-def get_inputs(
+def _get_inputs_gemma3(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
     num_key_value_heads: int,
     num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
     head_dim: int,
     width: int,
     height: int,
     num_channels: int,
     batch_size: int = 2,
-    sequence_length: int = 30,
-    sequence_length2: int = 3,
+    sequence_length: int = 43,
+    sequence_length2: int = 43,
     n_images: int = 2,
     dynamic_rope: bool = False,
-    add_second_input: int = 1,
+    max_sequence_length: int = 380,
     **kwargs,  # unused
 ):
     """
-    Generates input for task ``image-text-to-text``.
+    ::

-    :param model: model to get the missing information
-    :param config: configuration used to generate the model
-    :param head_dim: last dimension of the cache
-    :param dummy_max_token_id: dummy max token id
-    :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
-    :param n_images: number of images
-    :param width: width of the image
-    :param height: height of the image
-    :param num_channels: number of channels
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
-    :return: dictionary
+        dict(input_ids:T7s1x281,
+            pixel_values:T16s1x3x896x896,
+            attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
+            position_ids:T7s1x281,
+            past_key_values:HybridCache(
+                key_cache=#34[T1s1x4x380x256,...],
+                value_cache=#34[T1s1x4x380x256,...]),
+            token_type_ids:T7s1x281,
+            cache_position:T7s281,
+            logits_to_keep:1)
+        dict(input_ids:T7s1x1,
+            pixel_values:None,
+            attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
+            position_ids:T7s1x1,
+            past_key_values:HybridCache(
+                key_cache=#34[T1s1x4x380x256,...],
+                value_cache=#34[T1s1x4x380x256,...]),
+            token_type_ids:T7s1x1,
+            cache_position:T7s1,
+            logits_to_keep:1)
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
-    images = "images"  # torch.export.Dim("images", min=1, max=4096)
+    # cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

     shapes = {
         "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
+            "full_attention": {0: batch, 2: seq_length},
+            "sliding_attention": {0: batch, 2: seq_length},
         },
+        "position_ids": {0: batch, 1: seq_length},
+        "cache_position": {1: seq_length},
         "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
-        "pixel_values": {0: batch, 1: images},
-        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+        "pixel_values": {0: batch},
+        "use_cache": None,
     }
+
+    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+        torch.int64
+    )
+    input_ids[:, 1] = image_token_index
+    # input_ids[input_ids == image_token_index] = pad_token_id
+    token_type_ids = torch.zeros_like(input_ids)
+    token_type_ids[input_ids == image_token_index] = 1
+
     inputs = dict(
-        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
+        input_ids=input_ids,
+        token_type_ids=token_type_ids,
+        attention_mask=dict(
+            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
+            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
         ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_dynamic_cache(
+        cache_position=torch.arange(0, sequence_length).to(torch.int64),
+        position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+        past_key_values=make_hybrid_cache(
             [
                 (
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
                 )
                 for i in range(num_hidden_layers)
             ]
         ),
-        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
-            torch.int64
-        ),
+        pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
         image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
             torch.int64
         ),
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
-    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: int = 2,
+    sequence_length: int = 43,
+    sequence_length2: int = 43,
+    n_images: int = 2,
+    dynamic_rope: bool = False,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``image-text-to-text``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: dummy max token id
+    :param pad_token_id: pad_token_id
+    :param image_token_index: image_token_index
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :param n_images: number of images
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :return: dictionary
+    """
+    if model.__class__.__name__.startswith("Gemma3"):
+        res = _get_inputs_gemma3(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            n_images=n_images,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )
+    else:
+        assert (
+            "cls_cache" not in kwargs
+        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+        batch = torch.export.Dim("batch", min=1, max=1024)
+        batch_img = torch.export.Dim("batch_img", min=1, max=1024)
+        seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+        cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+        images = "images"  # torch.export.Dim("images", min=1, max=4096)
+
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "token_type_ids": {0: batch, 1: seq_length},
+            "attention_mask": {0: batch, 1: "cache+seq"},
+            "position_ids": {0: batch, 1: "cache+seq"},
+            "past_key_values": [
+                [{0: batch} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            "pixel_values": (
+                {0: batch, 1: images}
+                if model.__class__.__name__ == "IdeficsForVisionText2Text"
+                else {0: batch_img}
+            ),
+            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+            "use_cache": None,
+        }
+
+        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+            torch.int64
+        )
+        input_ids[0, 0] = image_token_index
+        input_ids[1, 1] = image_token_index
+        # input_ids[input_ids == image_token_index] = pad_token_id
+        token_type_ids = torch.zeros_like(input_ids)
+        token_type_ids[input_ids == image_token_index] = 1
+
+        inputs = dict(
+            input_ids=input_ids,
+            attention_mask=torch.cat(
+                [
+                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
+                    input_ids.ne(pad_token_id).to(torch.int64),
+                ],
+                axis=-1,
+            ),
+            position_ids=torch.arange(0, sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            pixel_values=(
+                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
+                if model.__class__.__name__ == "IdeficsForVisionText2Text"
+                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
+            ),
+            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+                torch.int64
+            ),
+            token_type_ids=token_type_ids,
+            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+        )
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
             add_second_input > 0
@@ -123,6 +308,8 @@ def get_inputs(
             sequence_length2=sequence_length2 + 1,
             n_images=n_images + 1,
             dynamic_rope=dynamic_rope,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -145,8 +332,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
                 ("num_key_value_heads", "num_attention_heads"),
                 "intermediate_size",
                 "hidden_size",
+                "pad_token_id",
             )
-            check_hasattr(config, "vision_config")
+            check_hasattr(config, "vision_config", "image_token_index")
             text_config = True
         else:
             check_hasattr(
@@ -163,19 +351,25 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         check_hasattr(config.vision_config, "image_size", "num_channels")
     kwargs = dict(
         batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
+        sequence_length=43,
+        sequence_length2=43,
         head_dim=(
             16
             if config is None
             else getattr(
                 config,
                 "head_dim",
-                (config.text_config.hidden_size if text_config else config.hidden_size)
-                // (
-                    config.text_config.num_attention_heads
-                    if text_config
-                    else config.num_attention_heads
+                (
+                    config.text_config.head_dim
+                    if text_config and hasattr(config.text_config, "head_dim")
+                    else (
+                        (config.text_config.hidden_size if text_config else config.hidden_size)
+                        // (
+                            config.text_config.num_attention_heads
+                            if text_config
+                            else config.num_attention_heads
+                        )
+                    )
                 ),
             )
         ),
@@ -219,5 +413,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         width=224 if config is None else config.vision_config.image_size,
         height=224 if config is None else config.vision_config.image_size,
         num_channels=3 if config is None else config.vision_config.num_channels,
+        pad_token_id=(
+            0
+            if config is None or not hasattr(config, "text_config")
+            else config.text_config.pad_token_id
+        ),
+        image_token_index=(
+            4
+            if config is None or not hasattr(config, "image_token_index")
+            else config.image_token_index
+        ),
     )
     return kwargs, get_inputs
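A note on the new helper used above: the Gemma3 branch pre-allocates past_key_values with make_hybrid_cache instead of make_dynamic_cache, passing one (key, value) tensor pair per layer at the maximum cache length. A minimal sketch of that call pattern, assuming make_hybrid_cache accepts the same list of pairs shown in the diff (the sizes below are illustrative, not the task defaults):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    # Illustrative sizes only; the task code derives them from the model config.
    batch_size, num_key_value_heads, max_sequence_length, head_dim = 1, 4, 380, 256
    num_hidden_layers = 2

    # One (key, value) pair per layer, pre-allocated to the maximum cache length,
    # mirroring the call made in _get_inputs_gemma3 above.
    past_key_values = make_hybrid_cache(
        [
            (
                torch.randn(batch_size, num_key_value_heads, max_sequence_length, head_dim),
                torch.randn(batch_size, num_key_value_heads, max_sequence_length, head_dim),
            )
            for _ in range(num_hidden_layers)
        ]
    )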
onnx_diagnostic/tasks/mask_generation.py
@@ -0,0 +1,139 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "mask-generation"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
+        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    output_channels: int,
+    window_size: int,
+    add_second_input: bool = True,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``mask-generation``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param batch_size: batch size
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels in the image
+    :param output_channels: number of output channels
+    :param window_size: size of the window for the vision model
+    :return: dictionary with inputs and dynamic shapes
+
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+
+    # TODO(anyone): input_masks is weirdly failing all the time with mismatch channels
+    # with Conv or embedding_size. I guess maybe the model is too implicit on the
+    # input_masks shape.
+
+    # TODO(titaiwang): modeling code specifically requires the height and width of inputs
+    # should be the same as the config.vision_config.image_size. Does that make sense?
+
+    shapes = {
+        "pixel_values": {0: "batch"},  # 1: num_channels is static
+        "input_points": {0: "batch", 1: "point_batch_size", 2: "nb_points_per_image"},
+        "input_boxes": {0: "batch", 1: "point_batch_size"},
+        # "input_masks": {0: "batch", 2: "height", 3: "width"},
+    }
+    inputs = dict(
+        pixel_values=torch.randn(
+            (batch_size, num_channels, height, width), dtype=torch.float32
+        ).clamp(-1, 1),
+        input_points=torch.randn(
+            (batch_size, 2, 10, 2), dtype=torch.float32
+        ),  # 10 points per image
+        input_boxes=torch.randn((batch_size, 2, 4), dtype=torch.float32),  # 1 box per image
+        # input_masks=torch.randn(
+        #     (batch_size, 1, height, width), dtype=torch.float32
+        # ),  # mask for the image
+    )
+
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            output_channels=output_channels,
+            window_size=window_size,
+            add_second_input=False,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        # generates mask as outputs
+        if hasattr(config, "mask_decoder_config"):
+            check_hasattr(
+                config.mask_decoder_config,
+                "hidden_size",
+                "iou_head_hidden_dim",
+                "iou_head_depth",
+                "num_hidden_layers",
+                "num_multimask_outputs",
+            )
+        if hasattr(config, "prompt_encoder_config"):
+            check_hasattr(
+                config.prompt_encoder_config,
+                "hidden_size",
+                "image_embedding_size",
+                "image_size",
+                "mask_input_channels",
+            )
+        if hasattr(config, "vision_config"):
+            check_hasattr(
+                config.vision_config,
+                "image_size",
+                "hidden_size",
+                "intermediate_size",
+                "num_hidden_layers",
+                "output_channels",
+                "num_channels",
+                "window_size",
+            )
+    kwargs = dict(
+        batch_size=2,
+        width=1024 if config is None else config.vision_config.image_size,
+        height=1024 if config is None else config.vision_config.image_size,
+        num_channels=3 if config is None else config.vision_config.num_channels,
+        output_channels=256 if config is None else config.vision_config.output_channels,
+        window_size=14 if config is None else config.vision_config.window_size,
+    )
+    return kwargs, get_inputs
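The new mask-generation module follows the same contract as the other task files: random_input_kwargs returns typical dimensions plus the get_inputs callable. A minimal usage sketch with config=None so the defaults above apply; passing model=None is an assumption here, since the shape logic shown in the diff never reads the model:

    from onnx_diagnostic.tasks.mask_generation import random_input_kwargs

    # config=None selects the typical SAM-like dimensions defined above.
    kwargs, fct = random_input_kwargs(None)
    data = fct(model=None, config=None, **kwargs)  # model/config unused for the dummy shapes (assumption)
    print(data["inputs"]["pixel_values"].shape)  # torch.Size([2, 3, 1024, 1024])
    print(list(data["dynamic_shapes"]))  # ['pixel_values', 'input_points', 'input_boxes']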
onnx_diagnostic/tasks/text2text_generation.py
@@ -69,8 +69,8 @@ def get_inputs(
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
-    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
+    cache_length = "cache_length_key"
+    cache_length2 = "cache_length_val"

     shapes = {
         "input_ids": {0: batch, 1: seq_length},
onnx_diagnostic/tasks/text_generation.py
@@ -1,6 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-import transformers
 from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
@@ -95,9 +94,14 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
         assert cls_cache in (
             "MambaCache",
-            transformers.cache_utils.MambaCache,
+            MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
         sequence_length = (
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -16,6 +16,8 @@ def get_function(name: str) -> Tuple[type, Callable]:
     module_name = ".".join(spl[:-1])
     fname = spl[-1]
     mod = importlib.import_module(module_name)
+    if not hasattr(mod, fname):
+        return None, None
     return mod, getattr(mod, fname)

@@ -33,12 +35,16 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
         doc = v.__doc__.lstrip()
         if doc.startswith("manual patch"):
             continue
-        reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+        reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
         fall = reg.findall(doc)
         assert (
             len(fall) == 1
         ), f"Unable to find patching information for {v} in \n{doc}"
         fmod, f = get_function(fall[0])
+        if fmod is None and f is None:
+            # The function does not exist in this version of transformers.
+            # No patch is needed.
+            continue
         to_patch.append({"module": fmod, "function": f, "patch": v})

     name = mod.__name__
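The regex in get_patches extracts the dotted target name from a [patch:...] tag in each patch class docstring; the new pattern only spells the brackets with explicit escapes. A small illustration of what it captures (the docstring text below is made up):

    import re

    # Same pattern as in the diff: capture the dotted name inside "[patch:...]".
    reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
    doc = "Patched attention.  [patch:transformers.models.llama.modeling_llama.eager_attention_forward]"
    print(reg.findall(doc))
    # ['transformers.models.llama.modeling_llama.eager_attention_forward']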
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -6,12 +6,17 @@ import torch
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
+    HybridCache,
     SlidingWindowCache,
     StaticCache,
 )

+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_

@@ -161,6 +166,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
+        flatten_hybrid_cache,
+        unflatten_hybrid_cache,
+        flatten_with_keys_hybrid_cache,
         flatten_mamba_cache,
         unflatten_mamba_cache,
         flatten_with_keys_mamba_cache,
@@ -187,6 +195,14 @@ def serialization_functions(
            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
            verbose=verbose,
        ),
+        HybridCache: lambda verbose=verbose: register_class_serialization(
+            HybridCache,
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            verbose=verbose,
+        ),
        MambaCache: lambda verbose=verbose: register_class_serialization(
            MambaCache,
            flatten_mamba_cache,
onnx_diagnostic/torch_export_patches/patch_inputs.py
@@ -70,6 +70,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
     :param verbose: verbosity
     :return: (args, kwargs, dynamic shapes)
     """
+    from ..helpers.cache_helper import CacheKeyValue
+
     new_kwargs = {}
     if args:
         assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
@@ -121,7 +123,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
             changes[k] = type(updated_kwargs[k])
             continue
         if isinstance(v, transformers.cache_utils.DynamicCache):
-            updated_kwargs[k] = [v.key_cache, v.value_cache]
+            ca = CacheKeyValue(v)
+            updated_kwargs[k] = [ca.key_cache, ca.value_cache]
             changes[k] = type(v)
             continue
         raise NotImplementedError(
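The last hunk stops reading key_cache and value_cache directly off a DynamicCache and goes through the CacheKeyValue wrapper from cache_helper, which appears to abstract over how different transformers versions store the per-layer tensors. A minimal sketch of that access pattern, assuming CacheKeyValue exposes key_cache and value_cache lists as used above (the tensors are placeholders):

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    # Build a one-layer DynamicCache from dummy tensors, then read it back through
    # the wrapper instead of touching cache.key_cache / cache.value_cache directly.
    cache = make_dynamic_cache([(torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8))])
    ca = CacheKeyValue(cache)
    flat = [ca.key_cache, ca.value_cache]  # the structure convert_dynamic_axes_into_dynamic_shapes expects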