onnx-diagnostic 0.7.11__py3-none-any.whl → 0.7.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +5 -2
  3. onnx_diagnostic/export/dynamic_shapes.py +11 -2
  4. onnx_diagnostic/helpers/helper.py +11 -5
  5. onnx_diagnostic/helpers/log_helper.py +65 -12
  6. onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
  7. onnx_diagnostic/helpers/model_builder_helper.py +1 -0
  8. onnx_diagnostic/helpers/rt_helper.py +55 -37
  9. onnx_diagnostic/helpers/torch_helper.py +31 -7
  10. onnx_diagnostic/reference/torch_evaluator.py +2 -2
  11. onnx_diagnostic/tasks/data/__init__.py +13 -0
  12. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  13. onnx_diagnostic/tasks/image_text_to_text.py +256 -141
  14. onnx_diagnostic/tasks/text_generation.py +15 -0
  15. onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
  16. onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +40 -14
  18. onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
  19. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
  21. onnx_diagnostic/torch_models/hghub/hub_api.py +4 -10
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +32 -4
  24. onnx_diagnostic/torch_models/validate.py +337 -113
  25. onnx_diagnostic/torch_onnx/sbs.py +2 -1
  26. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
  27. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +30 -28
  28. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.11.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ from ..helpers.config_helper import (
     _pick,
     default_num_hidden_layers as nhl,
 )
+from .data import get_data
 
 __TASK__ = "image-text-to-text"
 
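The hunks above and below come from onnx_diagnostic/tasks/image_text_to_text.py (entry 13 in the file list, +256 -141). The new import pulls get_data from the new onnx_diagnostic/tasks/data subpackage (entry 11), which ships the small dummies_imagetext2text_generation_gemma3.onnx file (entry 12). A later hunk indexes the result of get_data with a tuple key such as ("", 0, "I"), so it presumably unpacks tensors stored in that file. The sketch below only shows the general idea with a hypothetical load_dummy_tensors helper that recovers initializers by name; the actual layout used by get_data/mini_onnx_builder is an assumption here.

    # Hypothetical sketch: read the tensors stored in a small .onnx file shipped as
    # package data and return them as torch tensors keyed by initializer name.
    # The real get_data may return a differently keyed structure
    # (the diff indexes its result with a tuple key such as ("", 0, "I")).
    import os

    import onnx
    import torch
    from onnx import numpy_helper


    def load_dummy_tensors(name: str, folder: str = ".") -> dict:
        onx = onnx.load(os.path.join(folder, name))
        return {
            init.name: torch.from_numpy(numpy_helper.to_array(init).copy())
            for init in onx.graph.initializer
        }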
@@ -14,6 +15,27 @@ __TASK__ = "image-text-to-text"
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
+    if (
+        hasattr(config, "architectures")
+        and config.architectures
+        and config.architectures[0] == "Gemma3ForConditionalGeneration"
+    ):
+        if hasattr(config, "vision_config"):
+            if hasattr(config.vision_config, "num_hidden_layers"):
+                config.vision_config.num_hidden_layers = min(
+                    config.vision_config.num_hidden_layers, nhl()
+                )
+        if hasattr(config, "text_config"):
+            if hasattr(config.text_config, "intermediate_size"):
+                config.text_config.intermediate_size = min(
+                    config.text_config.intermediate_size, 10240 // 10 * 5 // 2
+                )
+                config.text_config.hidden_size = min(
+                    config.text_config.hidden_size, 2560 // 10 * 5 // 2
+                )
+        update_config(config, kwargs)
+        return kwargs
+
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     if hasattr(config, "mm_tokens_per_image"):
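The new Gemma3ForConditionalGeneration branch above caps the text configuration with the constants 10240 // 10 * 5 // 2 and 2560 // 10 * 5 // 2, which evaluate to 2560 and 640, i.e. a quarter of the values they are derived from. A standalone check of that arithmetic, with types.SimpleNamespace standing in for the transformers config object:

    from types import SimpleNamespace

    # Stand-in for config.text_config; 10240 and 2560 are the constants used in the hunk.
    text_config = SimpleNamespace(intermediate_size=10240, hidden_size=2560)

    # Same caps as in reduce_model_config above.
    text_config.intermediate_size = min(text_config.intermediate_size, 10240 // 10 * 5 // 2)
    text_config.hidden_size = min(text_config.hidden_size, 2560 // 10 * 5 // 2)

    print(text_config.intermediate_size, text_config.hidden_size)  # 2560 640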
@@ -72,54 +94,63 @@ def _get_inputs_gemma3(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 2,
-    sequence_length: int = 43,
-    sequence_length2: int = 43,
-    n_images: int = 2,
-    dynamic_rope: bool = False,
-    max_sequence_length: int = 380,
+    batch_size: Optional[int] = 1,
+    sequence_length: Optional[int] = 281,
+    n_images: Optional[int] = 1,
+    max_sequence_length: Optional[int] = 580,
+    total_sequence_length: Optional[int] = 860,
     **kwargs,  # unused
 ):
     """
+    The functions uses predefined values for input_ids and token_type_ids.
+
+    **google/gemma-3-4b-it**
+
+    iteration 1
+
     ::
+        cache_position:T7s281,
+        input_ids:T7s1x281,
+        token_type_ids:T7s1x281,
+        attention_mask:dict(sliding_attention:T9s1x1x281x580,
+                            full_attention:T9s1x1x281x580),
+        pixel_values:T16s1x3x896x896,
 
-        dict(input_ids:T7s1x281,
-             pixel_values:T16s1x3x896x896,
-             attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
-             position_ids:T7s1x281,
-             past_key_values:HybridCache(
-                 key_cache=#34[T1s1x4x380x256,...],
-                 value_cache=#34[T1s1x4x380x256,...]),
-             token_type_ids:T7s1x281,
-             cache_position:T7s281,
-             logits_to_keep:1)
-        dict(input_ids:T7s1x1,
-             pixel_values:None,
-             attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
-             position_ids:T7s1x1,
-             past_key_values:HybridCache(
-                 key_cache=#34[T1s1x4x380x256,...],
-                 value_cache=#34[T1s1x4x380x256,...]),
-             token_type_ids:T7s1x1,
-             cache_position:T7s1,
-             logits_to_keep:1)
+    iteration 2
+
+    ::
+
+        cache_position:T7s1,
+        past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+                                    value_cache=#34[T1s1x4x580x256,...]),
+        input_ids:T7s1x1,
+        inputs_embeds:None,
+        token_type_ids:T7s1x1,
+        attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        position_ids:None,
     """
+    batch_size = 1 if batch_size is None else batch_size
+    sequence_length = 281 if sequence_length is None else sequence_length
+    n_images = 1 if n_images is None else n_images
+    max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 860 if total_sequence_length is None else total_sequence_length
+
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
-    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    # cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+    seq_length = "seq_length"
+    tot_length = "total_length"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
         "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            "full_attention": {0: batch, 2: seq_length},
-            "sliding_attention": {0: batch, 2: seq_length},
+            "full_attention": {0: batch, 2: seq_length, 3: tot_length},
+            "sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
         },
         "position_ids": {0: batch, 1: seq_length},
-        "cache_position": {1: seq_length},
+        "cache_position": {0: seq_length},
         "past_key_values": [
             [{0: batch} for _ in range(num_hidden_layers)],
             [{0: batch} for _ in range(num_hidden_layers)],
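In the reworked dynamic shapes above, every dynamic axis is named with a plain string ("batch", "seq_length", "total_length") and the specification mirrors the nesting of the inputs, including the attention_mask sub-dict. The sketch below is not part of the library; it only illustrates, with throwaway tensors of the prefill sizes (1, 281, 860), how such a nested specification lines up with the nested inputs:

    import torch

    def check_alignment(value, spec):
        """Recursively verify that every axis named in ``spec`` exists in ``value``."""
        if isinstance(spec, dict) and isinstance(value, dict):
            for key, sub in spec.items():
                if sub is not None and key in value:
                    check_alignment(value[key], sub)
        elif isinstance(value, torch.Tensor):
            for dim_index, name in spec.items():
                assert dim_index < value.ndim, f"{name}: axis {dim_index} out of range"

    inputs = {
        "input_ids": torch.zeros((1, 281), dtype=torch.int64),
        "attention_mask": {
            "full_attention": torch.zeros((1, 1, 281, 860)),
            "sliding_attention": torch.zeros((1, 1, 281, 860)),
        },
    }
    shapes = {
        "input_ids": {0: "batch", 1: "seq_length"},
        "attention_mask": {
            "full_attention": {0: "batch", 2: "seq_length", 3: "total_length"},
            "sliding_attention": {0: "batch", 2: "seq_length", 3: "total_length"},
        },
    }
    check_alignment(inputs, shapes)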
@@ -128,23 +159,55 @@ def _get_inputs_gemma3(
         "use_cache": None,
     }
 
-    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-        torch.int64
-    )
-    input_ids[:, 1] = image_token_index
-    # input_ids[input_ids == image_token_index] = pad_token_id
-    token_type_ids = torch.zeros_like(input_ids)
-    token_type_ids[input_ids == image_token_index] = 1
+    # retrieve specific inputs to keep the consistency between
+    # ids and images
+    dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
+    dummies = dummies[("", 0, "I")][1]
+    dummies = {k: v for k, v in dummies.items() if k in shapes}
+    expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
+
+    def _check_():
+        assert expected & set(
+            dummies
+        ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
+        assert sequence_length == dummies["input_ids"].shape[-1], (
+            f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+            f"model class {model.__class__.__name__}"
+        )
+        assert batch_size == dummies["input_ids"].shape[0], (
+            f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
+            f"model class {model.__class__.__name__}"
+        )
+        assert max_sequence_length == 580, (
+            f"max_sequence_length={max_sequence_length} != 580 "
+            f"for model {model.__class__.__name__}"
+        )
+        assert total_sequence_length == 860, (
+            f"total_sequence_length={total_sequence_length} != 860 "
+            f"for model {model.__class__.__name__}"
+        )
+        assert (
+            head_dim == 256
+        ), f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+        assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
+        assert num_key_value_heads == 4, (
+            f"num_key_value_heads={num_key_value_heads} != 256 "
+            f"for this model {model.__class__.__name__}"
+        )
+
+    _check_()
 
     inputs = dict(
-        input_ids=input_ids,
-        token_type_ids=token_type_ids,
+        input_ids=dummies["input_ids"],
+        token_type_ids=dummies["token_type_ids"],
         attention_mask=dict(
-            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
-            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
+            full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length),
+            sliding_attention=torch.randn(
+                batch_size, 1, sequence_length, total_sequence_length
+            ),
         ),
-        cache_position=torch.arange(0, sequence_length).to(torch.int64),
         position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+        cache_position=torch.arange(0, sequence_length).to(torch.int64),
        past_key_values=make_hybrid_cache(
            [
                (
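Since input_ids and token_type_ids now come from a stored sample instead of torch.randint, the nested _check_() above rejects caller-supplied dimensions that disagree with that sample (batch 1, sequence 281, cache 580, mask 860, head_dim 256, 4 key/value heads). The same guard reduced to a standalone sketch; check_against_sample and its example tensor are illustrative only:

    import torch

    def check_against_sample(input_ids: torch.Tensor, batch_size: int, sequence_length: int):
        # The stored sample fixes both dimensions; reject anything else.
        assert batch_size == input_ids.shape[0], (
            f"batch_size={batch_size} != {input_ids.shape[0]} (stored sample)"
        )
        assert sequence_length == input_ids.shape[-1], (
            f"sequence_length={sequence_length} != {input_ids.shape[-1]} (stored sample)"
        )

    # The Gemma3 sample is 1x281, so only batch_size=1 / sequence_length=281 passes.
    check_against_sample(
        torch.zeros((1, 281), dtype=torch.int64), batch_size=1, sequence_length=281
    )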
@@ -159,12 +222,121 @@ def _get_inputs_gemma3(
             ]
         ),
         pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
-        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        #     torch.int64
+        # ),
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def get_inputs_default(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = 2,
+    sequence_length: Optional[int] = 43,
+    n_images: Optional[int] = 2,
+    max_sequence_length: Optional[int] = 43,
+    total_sequence_length: Optional[int] = 43,
+    add_second_input: int = 0,
+    **kwargs,  # unused
+):
+    batch_size = 2 if batch_size is None else batch_size
+    sequence_length = 43 if sequence_length is None else sequence_length
+    n_images = 2 if n_images is None else n_images
+    max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 43 if total_sequence_length is None else total_sequence_length
+
+    assert batch_size > 0, "batch_size cannot be null"
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    batch_img = torch.export.Dim("batch_img", min=1, max=1024)
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+    images = "images"  # torch.export.Dim("images", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: "cache+seq"},
+        "position_ids": {0: batch, 1: "cache+seq"},
+        "past_key_values": [
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+        ],
+        "pixel_values": (
+            {0: batch, 1: images}
+            if model.__class__.__name__ == "IdeficsForVisionText2Text"
+            else {0: batch_img}
+        ),
+        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+        "image_grid_thw": {0: batch},
+        "use_cache": None,
+    }
+
+    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
+        torch.int64
+    )
+    if total_sequence_length > 0:
+        input_ids[0, 0] = image_token_index
+        if min(input_ids.shape) > 1:
+            input_ids[1, 1] = image_token_index
+    # input_ids[input_ids == image_token_index] = pad_token_id
+    token_type_ids = torch.zeros_like(input_ids)
+    token_type_ids[input_ids == image_token_index] = 1
+    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
+    if n_images > 0:
+        image_grid_thw[:, 1] = height
+        image_grid_thw[:, 2] = width
+        image_grid_thw[0, :] //= 2
+        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
+
+    inputs = dict(
+        input_ids=input_ids,
+        token_type_ids=token_type_ids,
+        attention_mask=torch.cat(
+            [
+                torch.ones((batch_size, sequence_length), dtype=torch.int64),
+                input_ids.ne(pad_token_id).to(torch.int64),
+            ],
+            axis=-1,
+        ),
+        position_ids=torch.arange(0, total_sequence_length)
+        .to(torch.int64)
+        .expand((batch_size, -1)),
+        past_key_values=make_dynamic_cache(
+            [
+                (
+                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+        pixel_values=(
+            torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
+            if model.__class__.__name__ == "IdeficsForVisionText2Text"
+            else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
+        ),
+        image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to(
             torch.int64
         ),
+        image_grid_thw=image_grid_thw,
         use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
-    return dict(inputs=inputs, dynamic_shapes=shapes)
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    return res
 
 
 def get_inputs(
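get_inputs_default above is essentially the former else: branch of get_inputs moved into its own function; one notable part is the image_grid_thw tensor it builds. Running that construction on its own shows what it produces (n_images=2, width=224, height=224 are illustrative values, not the defaults used by the task):

    import torch

    n_images, width, height = 2, 224, 224  # illustrative values
    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
    if n_images > 0:
        image_grid_thw[:, 1] = height
        image_grid_thw[:, 2] = width
        image_grid_thw[0, :] //= 2  # first image uses a halved grid
        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
    print(image_grid_thw)
    # tensor([[  0, 112, 112],
    #         [  1, 224, 224]])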
@@ -179,12 +351,12 @@ def get_inputs(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 2,
-    sequence_length: int = 43,
-    sequence_length2: int = 43,
-    n_images: int = 2,
-    dynamic_rope: bool = False,
-    add_second_input: int = 1,
+    batch_size: Optional[int] = None,
+    sequence_length: Optional[int] = None,
+    n_images: Optional[int] = None,
+    max_sequence_length: Optional[int] = None,
+    total_sequence_length: Optional[int] = None,
+    add_second_input: int = 0,
     **kwargs,  # unused
 ):
     """
@@ -198,13 +370,19 @@ def get_inputs(
     :param image_token_index: image_token_index
     :param batch_size: batch size
     :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
+    :param max_sequence_length: for the cache
+    :param total_sequence_length: for the mask
     :param n_images: number of images
     :param width: width of the image
     :param height: height of the image
     :param num_channels: number of channels
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
+
+    .. note::
+
+        The content of the input_ids and its shape is correlated to the images.
+        The function uses a predefined values. The function raises an exception
+        if dimension are not the expected ones.
     """
     if model.__class__.__name__.startswith("Gemma3"):
         res = _get_inputs_gemma3(
@@ -221,92 +399,32 @@ def get_inputs(
             num_channels=num_channels,
             batch_size=batch_size,
             sequence_length=sequence_length,
-            sequence_length2=sequence_length2,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
             n_images=n_images,
-            dynamic_rope=dynamic_rope,
             **kwargs,
         )
     else:
-        assert (
-            "cls_cache" not in kwargs
-        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-        batch = "batch"
-        batch_img = torch.export.Dim("batch_img", min=1, max=1024)
-        seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-        cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
-        images = "images"  # torch.export.Dim("images", min=1, max=4096)
-
-        shapes = {
-            "input_ids": {0: batch, 1: seq_length},
-            "token_type_ids": {0: batch, 1: seq_length},
-            "attention_mask": {0: batch, 1: "cache+seq"},
-            "position_ids": {0: batch, 1: "cache+seq"},
-            "past_key_values": [
-                [{0: batch} for _ in range(num_hidden_layers)],
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            ],
-            "pixel_values": (
-                {0: batch, 1: images}
-                if model.__class__.__name__ == "IdeficsForVisionText2Text"
-                else {0: batch_img}
-            ),
-            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
-            "image_grid_thw": {0: batch},
-            "use_cache": None,
-        }
-
-        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
+        res = get_inputs_default(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
         )
-        input_ids[0, 0] = image_token_index
-        input_ids[1, 1] = image_token_index
-        # input_ids[input_ids == image_token_index] = pad_token_id
-        token_type_ids = torch.zeros_like(input_ids)
-        token_type_ids[input_ids == image_token_index] = 1
-        image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
-        image_grid_thw[:, 1] = height
-        image_grid_thw[:, 2] = width
-        image_grid_thw[0, :] //= 2
-        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
 
-        inputs = dict(
-            input_ids=input_ids,
-            attention_mask=torch.cat(
-                [
-                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
-                    input_ids.ne(pad_token_id).to(torch.int64),
-                ],
-                axis=-1,
-            ),
-            position_ids=torch.arange(0, sequence_length2)
-            .to(torch.int64)
-            .expand((batch_size, -1)),
-            past_key_values=make_dynamic_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
-                        ),
-                        torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
-            pixel_values=(
-                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
-                if model.__class__.__name__ == "IdeficsForVisionText2Text"
-                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
-            ),
-            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-                torch.int64
-            ),
-            token_type_ids=token_type_ids,
-            image_grid_thw=image_grid_thw,
-            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
-        )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
 
     if add_second_input:
         assert (
@@ -321,11 +439,11 @@ def get_inputs(
             width=width,
             height=height,
             num_channels=num_channels,
-            batch_size=batch_size + 1,
-            sequence_length=sequence_length + add_second_input,
-            sequence_length2=sequence_length2 + 1,
-            n_images=n_images + 1,
-            dynamic_rope=dynamic_rope,
+            batch_size=3,
+            sequence_length=1,
+            max_sequence_length=1,
+            total_sequence_length=1,
+            n_images=0,
             pad_token_id=pad_token_id,
             image_token_index=image_token_index,
             add_second_input=0,
@@ -368,9 +486,6 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         text_config = False
         check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
-        batch_size=2,
-        sequence_length=43,
-        sequence_length2=43,
         head_dim=(
             16
             if config is None
@@ -269,6 +269,21 @@ def get_inputs(
             add_second_input=0,
             **kwargs,
         )["inputs"]
+        res["inputs_empty_cache"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=0,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
     return res
 
 
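This last hunk adds 15 lines to a get_inputs whose keyword arguments (sequence_length2, dynamic_rope, cls_cache) match the text-generation task, so it presumably belongs to onnx_diagnostic/tasks/text_generation.py (entry 14, +15 -0). The new inputs_empty_cache variant is built with sequence_length=0, which should yield past_key_values tensors with an empty sequence axis. A minimal sketch of what such an empty cache looks like; the layer sizes below are illustrative, not the library defaults:

    import torch

    # Illustrative sizes (not the library defaults).
    batch_size, num_key_value_heads, head_dim, num_hidden_layers = 2, 4, 16, 2
    sequence_length = 0  # what inputs_empty_cache requests

    empty_layers = [
        (
            torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
            torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
        )
        for _ in range(num_hidden_layers)
    ]
    print(empty_layers[0][0].shape)  # torch.Size([2, 4, 0, 16])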