onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +56 -3
  3. onnx_diagnostic/export/dynamic_shapes.py +24 -10
  4. onnx_diagnostic/export/shape_helper.py +6 -2
  5. onnx_diagnostic/ext_test_case.py +2 -0
  6. onnx_diagnostic/helpers/_log_helper.py +6 -6
  7. onnx_diagnostic/helpers/cache_helper.py +326 -18
  8. onnx_diagnostic/helpers/config_helper.py +10 -0
  9. onnx_diagnostic/helpers/helper.py +152 -11
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  11. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  12. onnx_diagnostic/helpers/torch_helper.py +33 -11
  13. onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
  14. onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
  15. onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
  16. onnx_diagnostic/tasks/__init__.py +2 -0
  17. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  18. onnx_diagnostic/tasks/feature_extraction.py +7 -3
  19. onnx_diagnostic/tasks/fill_mask.py +6 -2
  20. onnx_diagnostic/tasks/image_classification.py +6 -2
  21. onnx_diagnostic/tasks/image_text_to_text.py +289 -62
  22. onnx_diagnostic/tasks/mask_generation.py +143 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
  24. onnx_diagnostic/tasks/object_detection.py +6 -2
  25. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  26. onnx_diagnostic/tasks/summarization.py +7 -2
  27. onnx_diagnostic/tasks/text2text_generation.py +7 -2
  28. onnx_diagnostic/tasks/text_classification.py +6 -2
  29. onnx_diagnostic/tasks/text_generation.py +14 -16
  30. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
  31. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  32. onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
  33. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
  35. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
  36. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
  37. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  38. onnx_diagnostic/torch_models/validate.py +1 -0
  39. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
  40. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
  41. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/image_text_to_text.py
@@ -1,7 +1,12 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache
-from ..helpers.config_helper import update_config, check_hasattr, _pick
+from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "image-text-to-text"
 
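Note: the newly imported make_hybrid_cache mirrors make_dynamic_cache and, as the hunks below show, takes one (key, value) tensor pair per layer. A minimal sketch of a call, with dimension values borrowed from the Gemma3 example further down (the values themselves are only illustrative):

import torch
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

# one (key, value) pair per hidden layer; shape: batch x heads x max_seq x head_dim
past_key_values = make_hybrid_cache(
    [
        (torch.randn(1, 4, 380, 256), torch.randn(1, 4, 380, 256))
        for _ in range(2)
    ]
)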
@@ -10,100 +15,285 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, 2)
-    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
-        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "mm_tokens_per_image"):
+        config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
+    if hasattr(config, "vision_config"):
+        if hasattr(config.vision_config, "num_hidden_layers"):
+            config.vision_config.num_hidden_layers = min(
+                config.vision_config.num_hidden_layers, 2
+            )
+        if hasattr(config.vision_config, "image_size"):
+            config.vision_config.image_size = min(config.vision_config.image_size, 96)
+        if hasattr(config.vision_config, "intermediate_size"):
+            config.vision_config.intermediate_size = min(
+                config.vision_config.intermediate_size, 1076
+            )
+        if hasattr(config.vision_config, "patch_size"):
+            config.vision_config.patch_size = min(config.vision_config.patch_size, 2)
+        if hasattr(config.vision_config, "hidden_size"):
+            config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
+    if hasattr(config, "text_config"):
+        if hasattr(config.text_config, "intermediate_size"):
+            config.text_config.intermediate_size = min(
+                config.text_config.intermediate_size, 320
+            )
+        if hasattr(config.text_config, "hidden_size"):
+            config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
+        if hasattr(config.text_config, "num_hidden_layers"):
+            config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
+        if hasattr(config.text_config, "layer_types"):
+            config.text_config.layer_types = config.text_config.layer_types[
+                : config.text_config.num_hidden_layers
+            ]
+        if hasattr(config.text_config, "num_attention_heads"):
+            config.text_config.num_attention_heads = min(
+                config.text_config.num_attention_heads, 2
+            )
     update_config(config, kwargs)
     return kwargs
 
 
-def get_inputs(
+def _get_inputs_gemma3(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
     num_key_value_heads: int,
     num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
     head_dim: int,
     width: int,
     height: int,
     num_channels: int,
     batch_size: int = 2,
-    sequence_length: int = 30,
-    sequence_length2: int = 3,
+    sequence_length: int = 43,
+    sequence_length2: int = 43,
     n_images: int = 2,
     dynamic_rope: bool = False,
-    add_second_input: int = 1,
+    max_sequence_length: int = 380,
     **kwargs,  # unused
 ):
     """
-    Generates input for task ``image-text-to-text``.
+    ::
 
-    :param model: model to get the missing information
-    :param config: configuration used to generate the model
-    :param head_dim: last dimension of the cache
-    :param dummy_max_token_id: dummy max token id
-    :param batch_size: batch size
-    :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
-    :param n_images: number of images
-    :param width: width of the image
-    :param height: height of the image
-    :param num_channels: number of channels
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
-    :return: dictionary
+        dict(input_ids:T7s1x281,
+             pixel_values:T16s1x3x896x896,
+             attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
+             position_ids:T7s1x281,
+             past_key_values:HybridCache(
+                key_cache=#34[T1s1x4x380x256,...],
+                value_cache=#34[T1s1x4x380x256,...]),
+             token_type_ids:T7s1x281,
+             cache_position:T7s281,
+             logits_to_keep:1)
+        dict(input_ids:T7s1x1,
+             pixel_values:None,
+             attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
+             position_ids:T7s1x1,
+             past_key_values:HybridCache(
+                key_cache=#34[T1s1x4x380x256,...],
+                value_cache=#34[T1s1x4x380x256,...]),
+             token_type_ids:T7s1x1,
+             cache_position:T7s1,
+             logits_to_keep:1)
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
-    images = "images"  # torch.export.Dim("images", min=1, max=4096)
+    # cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
+            "full_attention": {0: batch, 2: seq_length},
+            "sliding_attention": {0: batch, 2: seq_length},
         },
+        "position_ids": {0: batch, 1: seq_length},
+        "cache_position": {1: seq_length},
         "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
-        "pixel_values": {0: batch, 1: images},
-        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+        "pixel_values": {0: batch},
+        "use_cache": None,
     }
+
+    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+        torch.int64
+    )
+    input_ids[:, 1] = image_token_index
+    # input_ids[input_ids == image_token_index] = pad_token_id
+    token_type_ids = torch.zeros_like(input_ids)
+    token_type_ids[input_ids == image_token_index] = 1
+
     inputs = dict(
-        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
+        input_ids=input_ids,
+        token_type_ids=token_type_ids,
+        attention_mask=dict(
+            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
+            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
         ),
-        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
-            torch.int64
-        ),
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_dynamic_cache(
+        cache_position=torch.arange(0, sequence_length).to(torch.int64),
+        position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+        past_key_values=make_hybrid_cache(
            [
                (
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
-                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
                )
                for i in range(num_hidden_layers)
            ]
        ),
-        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
-            torch.int64
-        ),
+        pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
         image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
            torch.int64
        ),
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
    )
-    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: int = 2,
+    sequence_length: int = 43,
+    sequence_length2: int = 43,
+    n_images: int = 2,
+    dynamic_rope: bool = False,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``image-text-to-text``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: dummy max token id
+    :param pad_token_id: pad_token_id
+    :param image_token_index: image_token_index
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :param n_images: number of images
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :return: dictionary
+    """
+    if model.__class__.__name__.startswith("Gemma3"):
+        res = _get_inputs_gemma3(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            n_images=n_images,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )
+    else:
+        assert (
+            "cls_cache" not in kwargs
+        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+        batch = torch.export.Dim("batch", min=1, max=1024)
+        batch_img = torch.export.Dim("batch_img", min=1, max=1024)
+        seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+        cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+        images = "images"  # torch.export.Dim("images", min=1, max=4096)
+
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "token_type_ids": {0: batch, 1: seq_length},
+            "attention_mask": {0: batch, 1: "cache+seq"},
+            "position_ids": {0: batch, 1: "cache+seq"},
+            "past_key_values": [
+                [{0: batch} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            "pixel_values": (
+                {0: batch, 1: images}
+                if model.__class__.__name__ == "IdeficsForVisionText2Text"
+                else {0: batch_img}
+            ),
+            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+            "use_cache": None,
+        }
+
+        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+            torch.int64
+        )
+        input_ids[0, 0] = image_token_index
+        input_ids[1, 1] = image_token_index
+        # input_ids[input_ids == image_token_index] = pad_token_id
+        token_type_ids = torch.zeros_like(input_ids)
+        token_type_ids[input_ids == image_token_index] = 1
+
+        inputs = dict(
+            input_ids=input_ids,
+            attention_mask=torch.cat(
+                [
+                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
+                    input_ids.ne(pad_token_id).to(torch.int64),
+                ],
+                axis=-1,
+            ),
+            position_ids=torch.arange(0, sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            pixel_values=(
+                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
+                if model.__class__.__name__ == "IdeficsForVisionText2Text"
+                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
+            ),
+            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+                torch.int64
+            ),
+            token_type_ids=token_type_ids,
+            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+        )
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
             add_second_input > 0
@@ -123,6 +313,8 @@ def get_inputs(
             sequence_length2=sequence_length2 + 1,
             n_images=n_images + 1,
             dynamic_rope=dynamic_rope,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -145,8 +337,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             ("num_key_value_heads", "num_attention_heads"),
             "intermediate_size",
             "hidden_size",
+            "pad_token_id",
         )
-        check_hasattr(config, "vision_config")
+        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -160,22 +353,28 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-    check_hasattr(config.vision_config, "image_size", "num_channels")
+    check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
+        sequence_length=43,
+        sequence_length2=43,
         head_dim=(
             16
             if config is None
             else getattr(
                 config,
                 "head_dim",
-                (config.text_config.hidden_size if text_config else config.hidden_size)
-                // (
-                    config.text_config.num_attention_heads
-                    if text_config
-                    else config.num_attention_heads
+                (
+                    config.text_config.head_dim
+                    if text_config and hasattr(config.text_config, "head_dim")
+                    else (
+                        (config.text_config.hidden_size if text_config else config.hidden_size)
+                        // (
+                            config.text_config.num_attention_heads
+                            if text_config
+                            else config.num_attention_heads
+                        )
+                    )
                 ),
             )
         ),
@@ -216,8 +415,36 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
            if config is None
            else (config.text_config.hidden_size if text_config else config.hidden_size)
        ),
-        width=224 if config is None else config.vision_config.image_size,
-        height=224 if config is None else config.vision_config.image_size,
-        num_channels=3 if config is None else config.vision_config.num_channels,
+        width=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        height=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        num_channels=(
+            3
+            if config is None
+            else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
+        ),
+        pad_token_id=(
+            0
+            if config is None
+            or not hasattr(config, "text_config")
+            or not hasattr(config.text_config, "pad_token_id")
+            else config.text_config.pad_token_id
+        ),
+        image_token_index=(
+            4
+            if config is None
+            or (
+                not hasattr(config, "image_token_index")
+                and not hasattr(config, "image_token_id")
+            )
+            else _pick(config, "image_token_index", "image_token_id")
+        ),
    )
    return kwargs, get_inputs
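For context, a hedged sketch of how these two entry points are typically chained. The pattern follows the module itself (random_input_kwargs fills the keyword arguments, get_inputs builds the tensors); the only assumption is that the remaining required arguments (dummy_max_token_id, num_key_value_heads, num_hidden_layers, ...) are also filled in by random_input_kwargs, as they are for the other tasks:

# Sketch only, not part of the diff; `model` and `config` are assumed to be a
# transformers model/config pair.
from onnx_diagnostic.tasks.image_text_to_text import random_input_kwargs

kwargs, fct = random_input_kwargs(config)
data = fct(model, config, **kwargs)
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
# A model whose class name starts with "Gemma3" is routed through
# _get_inputs_gemma3 (HybridCache, attention_mask split into
# full_attention/sliding_attention); every other model keeps the
# DynamicCache path, with a separate batch_img dimension for pixel_values.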
onnx_diagnostic/tasks/mask_generation.py (new file)
@@ -0,0 +1,143 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "mask-generation"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
+        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    output_channels: int,
+    window_size: int,
+    add_second_input: bool = True,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``mask-generation``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param batch_size: batch size
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels in the image
+    :param output_channels: number of output channels
+    :param window_size: size of the window for the vision model
+    :return: dictionary with inputs and dynamic shapes
+
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+
+    # TODO(anyone): input_masks is weirdly failing all the time with mismatch channels
+    # with Conv or embedding_size. I guess maybe the model is too implicit on the
+    # input_masks shape.
+
+    # TODO(titaiwang): modeling code specifically requires the height and width of inputs
+    # should be the same as the config.vision_config.image_size. Does that make sense?
+
+    shapes = {
+        "pixel_values": {0: "batch"},  # 1: num_channels is static
+        "input_points": {0: "batch", 1: "point_batch_size", 2: "nb_points_per_image"},
+        "input_boxes": {0: "batch", 1: "point_batch_size"},
+        # "input_masks": {0: "batch", 2: "height", 3: "width"},
+    }
+    inputs = dict(
+        pixel_values=torch.randn(
+            (batch_size, num_channels, height, width), dtype=torch.float32
+        ).clamp(-1, 1),
+        input_points=torch.randn(
+            (batch_size, 2, 10, 2), dtype=torch.float32
+        ),  # 10 points per image
+        input_boxes=torch.randn((batch_size, 2, 4), dtype=torch.float32),  # 1 box per image
+        # input_masks=torch.randn(
+        #     (batch_size, 1, height, width), dtype=torch.float32
+        # ),  # mask for the image
+    )
+
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            output_channels=output_channels,
+            window_size=window_size,
+            add_second_input=False,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        # generates mask as outputs
+        if hasattr(config, "mask_decoder_config"):
+            check_hasattr(
+                config.mask_decoder_config,
+                "hidden_size",
+                "iou_head_hidden_dim",
+                "iou_head_depth",
+                "num_hidden_layers",
+                "num_multimask_outputs",
+            )
+        if hasattr(config, "prompt_encoder_config"):
+            check_hasattr(
+                config.prompt_encoder_config,
+                "hidden_size",
+                "image_embedding_size",
+                "image_size",
+                "mask_input_channels",
+            )
+        if hasattr(config, "vision_config"):
+            check_hasattr(
+                config.vision_config,
+                "image_size",
+                "hidden_size",
+                "intermediate_size",
+                "num_hidden_layers",
+                "output_channels",
+                "num_channels",
+                "window_size",
+            )
+    kwargs = dict(
+        batch_size=2,
+        width=1024 if config is None else config.vision_config.image_size,
+        height=1024 if config is None else config.vision_config.image_size,
+        num_channels=3 if config is None else config.vision_config.num_channels,
+        output_channels=256 if config is None else config.vision_config.output_channels,
+        window_size=14 if config is None else config.vision_config.window_size,
+    )
+    return kwargs, get_inputs
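A hedged usage sketch for the new mask-generation task module; it follows the same contract as the other task modules, and the assumed target is a SAM-style transformers config carrying vision_config, prompt_encoder_config and mask_decoder_config:

# Sketch only, not part of the diff.
from onnx_diagnostic.tasks.mask_generation import random_input_kwargs

kwargs, fct = random_input_kwargs(config)   # config=None falls back to typical dimensions
data = fct(model, config, add_second_input=True, **kwargs)
inputs = data["inputs"]            # pixel_values, input_points, input_boxes
shapes = data["dynamic_shapes"]    # "batch", "point_batch_size", "nb_points_per_image"
second = data["inputs2"]           # same structure with batch_size + 1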
onnx_diagnostic/tasks/mixture_of_expert.py
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 
 # from ..helpers.cache_helper import make_dynamic_cache
-from ..helpers.config_helper import update_config  # , check_hasattr, _pick
+from ..helpers.config_helper import update_config, default_num_hidden_layers as nhl
 
 __TASK__ = "MoE"
 
@@ -11,7 +11,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
         config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
     if hasattr(config, "audio_processor") and hasattr(
onnx_diagnostic/tasks/object_detection.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "object-detection"
 
@@ -10,7 +14,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
     kwargs = dict(
         num_hidden_layers=(
-            min(config.num_hidden_layers, 2)
+            min(config.num_hidden_layers, nhl())
             if hasattr(config, "num_hidden_layers")
             else len(config.hidden_sizes)
         )
onnx_diagnostic/tasks/sentence_similarity.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "sentence-similarity"
 
@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
+        num_hidden_layers=min(config.num_hidden_layers, nhl()),
         num_attention_heads=min(config.num_attention_heads, 4),
     )
     update_config(config, kwargs)
onnx_diagnostic/tasks/summarization.py
@@ -1,7 +1,12 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import update_config, check_hasattr, _pick
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "summarization"
 
@@ -12,7 +17,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if hasattr(config, "num_decoder_layers"):
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     update_config(config, kwargs)
     return kwargs
 
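The change repeated across these task files swaps the hard-coded cap of 2 hidden layers for default_num_hidden_layers() (imported as nhl) from onnx_diagnostic/helpers/config_helper.py. That helper is not shown in this diff; a minimal sketch of the pattern, assuming it simply returns a small default cap, would be:

# Sketch only: the real default_num_hidden_layers is defined in
# onnx_diagnostic/helpers/config_helper.py and is not part of this diff.
def default_num_hidden_layers() -> int:
    # assumption: a small default cap (2 was the previous hard-coded value)
    return 2

def reduce_model_config(config):
    # the recurring pattern used by the task modules above
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, default_num_hidden_layers())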