onnx-diagnostic 0.7.12__py3-none-any.whl → 0.7.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/export/dynamic_shapes.py +11 -2
- onnx_diagnostic/helpers/helper.py +11 -5
- onnx_diagnostic/helpers/mini_onnx_builder.py +17 -0
- onnx_diagnostic/helpers/model_builder_helper.py +1 -0
- onnx_diagnostic/helpers/rt_helper.py +2 -1
- onnx_diagnostic/helpers/torch_helper.py +31 -7
- onnx_diagnostic/reference/torch_evaluator.py +2 -2
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/image_text_to_text.py +256 -141
- onnx_diagnostic/tasks/text_generation.py +15 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +177 -150
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +19 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +29 -14
- onnx_diagnostic/torch_export_patches/patch_inputs.py +10 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +116 -10
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +269 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +36 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +31 -3
- onnx_diagnostic/torch_models/validate.py +114 -36
- onnx_diagnostic/torch_onnx/sbs.py +2 -1
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/METADATA +11 -31
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/RECORD +27 -25
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.12.dist-info → onnx_diagnostic-0.7.13.dist-info}/top_level.txt +0 -0
onnx_diagnostic/tasks/image_text_to_text.py

@@ -7,6 +7,7 @@ from ..helpers.config_helper import (
     _pick,
     default_num_hidden_layers as nhl,
 )
+from .data import get_data
 
 __TASK__ = "image-text-to-text"
 
@@ -14,6 +15,27 @@ __TASK__ = "image-text-to-text"
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
+    if (
+        hasattr(config, "architectures")
+        and config.architectures
+        and config.architectures[0] == "Gemma3ForConditionalGeneration"
+    ):
+        if hasattr(config, "vision_config"):
+            if hasattr(config.vision_config, "num_hidden_layers"):
+                config.vision_config.num_hidden_layers = min(
+                    config.vision_config.num_hidden_layers, nhl()
+                )
+        if hasattr(config, "text_config"):
+            if hasattr(config.text_config, "intermediate_size"):
+                config.text_config.intermediate_size = min(
+                    config.text_config.intermediate_size, 10240 // 10 * 5 // 2
+                )
+                config.text_config.hidden_size = min(
+                    config.text_config.hidden_size, 2560 // 10 * 5 // 2
+                )
+        update_config(config, kwargs)
+        return kwargs
+
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     if hasattr(config, "mm_tokens_per_image"):
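
The new branch in `reduce_model_config` shrinks the nested Gemma3 sub-configs instead of only the flat attributes: the vision tower keeps at most `nhl()` layers, and the text tower is capped at `10240 // 10 * 5 // 2 = 2560` for `intermediate_size` and `2560 // 10 * 5 // 2 = 640` for `hidden_size`. A minimal sketch of that arithmetic on a stand-in config (not transformers' real `Gemma3Config`, with an assumed `nhl()` of 2):

    # Stand-in objects only; the real entry point is
    # onnx_diagnostic.tasks.image_text_to_text.reduce_model_config.
    from types import SimpleNamespace

    def nhl() -> int:
        return 2  # assumed default number of hidden layers

    cfg = SimpleNamespace(
        architectures=["Gemma3ForConditionalGeneration"],
        vision_config=SimpleNamespace(num_hidden_layers=27),
        text_config=SimpleNamespace(intermediate_size=10240, hidden_size=2560),
    )

    # same caps as the added branch
    cfg.vision_config.num_hidden_layers = min(cfg.vision_config.num_hidden_layers, nhl())
    cfg.text_config.intermediate_size = min(cfg.text_config.intermediate_size, 10240 // 10 * 5 // 2)
    cfg.text_config.hidden_size = min(cfg.text_config.hidden_size, 2560 // 10 * 5 // 2)

    print(cfg.vision_config.num_hidden_layers)  # 2
    print(cfg.text_config.intermediate_size)    # 2560
    print(cfg.text_config.hidden_size)          # 640
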
@@ -72,54 +94,63 @@ def _get_inputs_gemma3(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int =
-    sequence_length: int =
-
-
-
-    max_sequence_length: int = 380,
+    batch_size: Optional[int] = 1,
+    sequence_length: Optional[int] = 281,
+    n_images: Optional[int] = 1,
+    max_sequence_length: Optional[int] = 580,
+    total_sequence_length: Optional[int] = 860,
     **kwargs,  # unused
 ):
     """
+    The functions uses predefined values for input_ids and token_type_ids.
+
+    **google/gemma-3-4b-it**
+
+    iteration 1
+
     ::
+
+        cache_position:T7s281,
+        input_ids:T7s1x281,
+        token_type_ids:T7s1x281,
+        attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        full_attention:T9s1x1x281x580),
+        pixel_values:T16s1x3x896x896,
 
-
-
-
-
-
-
-
-
-
-
-
-
-        attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
-        position_ids:T7s1x1,
-        past_key_values:HybridCache(
-            key_cache=#34[T1s1x4x380x256,...],
-            value_cache=#34[T1s1x4x380x256,...]),
-        token_type_ids:T7s1x1,
-        cache_position:T7s1,
-        logits_to_keep:1)
+    iteration 2
+
+    ::
+
+        cache_position:T7s1,
+        past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        value_cache=#34[T1s1x4x580x256,...]),
+        input_ids:T7s1x1,
+        inputs_embeds:None,
+        token_type_ids:T7s1x1,
+        attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        position_ids:None,
     """
+    batch_size = 1 if batch_size is None else batch_size
+    sequence_length = 281 if sequence_length is None else sequence_length
+    n_images = 1 if n_images is None else n_images
+    max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 860 if total_sequence_length is None else total_sequence_length
+
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
-    seq_length = "seq_length"
-
+    seq_length = "seq_length"
+    tot_length = "total_length"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
         "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            "full_attention": {0: batch, 2: seq_length},
-            "sliding_attention": {0: batch, 2: seq_length},
+            "full_attention": {0: batch, 2: seq_length, 3: tot_length},
+            "sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
         },
         "position_ids": {0: batch, 1: seq_length},
-        "cache_position": {
+        "cache_position": {0: seq_length},
         "past_key_values": [
             [{0: batch} for _ in range(num_hidden_layers)],
             [{0: batch} for _ in range(num_hidden_layers)],
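
The dynamic-shapes mapping now distinguishes the query length (`seq_length`) from the mask's last axis (`tot_length`), matching the iteration-1 trace where `input_ids` is `1x281` while both masks are `1x1x281x580`. A small, self-contained check of that layout (illustration only, not package code):

    import torch

    batch, seq_length, tot_length = 1, 281, 580
    attention_mask = {
        "full_attention": torch.randn(batch, 1, seq_length, tot_length),
        "sliding_attention": torch.randn(batch, 1, seq_length, tot_length),
    }
    # mirrors the updated entries of the `shapes` dictionary above
    dynamic = {
        "full_attention": {0: "batch", 2: "seq_length", 3: "total_length"},
        "sliding_attention": {0: "batch", 2: "seq_length", 3: "total_length"},
    }
    for name, spec in dynamic.items():
        assert max(spec) < attention_mask[name].ndim  # every dynamic axis exists
    print({k: tuple(v.shape) for k, v in attention_mask.items()})
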
@@ -128,23 +159,55 @@ def _get_inputs_gemma3(
         "use_cache": None,
     }
 
-
-
-    )
-
-
-
-
+    # retrieve specific inputs to keep the consistency between
+    # ids and images
+    dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
+    dummies = dummies[("", 0, "I")][1]
+    dummies = {k: v for k, v in dummies.items() if k in shapes}
+    expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
+
+    def _check_():
+        assert expected & set(
+            dummies
+        ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
+        assert sequence_length == dummies["input_ids"].shape[-1], (
+            f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+            f"model class {model.__class__.__name__}"
+        )
+        assert batch_size == dummies["input_ids"].shape[0], (
+            f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
+            f"model class {model.__class__.__name__}"
+        )
+        assert max_sequence_length == 580, (
+            f"max_sequence_length={max_sequence_length} != 580 "
+            f"for model {model.__class__.__name__}"
+        )
+        assert total_sequence_length == 860, (
+            f"total_sequence_length={total_sequence_length} != 860 "
+            f"for model {model.__class__.__name__}"
+        )
+        assert (
+            head_dim == 256
+        ), f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+        assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
+        assert num_key_value_heads == 4, (
+            f"num_key_value_heads={num_key_value_heads} != 256 "
+            f"for this model {model.__class__.__name__}"
+        )
+
+    _check_()
 
     inputs = dict(
-        input_ids=input_ids,
-        token_type_ids=token_type_ids,
+        input_ids=dummies["input_ids"],
+        token_type_ids=dummies["token_type_ids"],
         attention_mask=dict(
-            full_attention=torch.randn(batch_size, 1, sequence_length,
-            sliding_attention=torch.randn(
+            full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length),
+            sliding_attention=torch.randn(
+                batch_size, 1, sequence_length, total_sequence_length
+            ),
         ),
-        cache_position=torch.arange(0, sequence_length).to(torch.int64),
         position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+        cache_position=torch.arange(0, sequence_length).to(torch.int64),
        past_key_values=make_hybrid_cache(
             [
                 (
|
|
|
159
222
|
]
|
|
160
223
|
),
|
|
161
224
|
pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
|
|
162
|
-
image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
|
|
225
|
+
# image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
|
|
226
|
+
# torch.int64
|
|
227
|
+
# ),
|
|
228
|
+
use_cache=True, # Gemma3 does not set this value to true when a cache is provided
|
|
229
|
+
)
|
|
230
|
+
return dict(inputs=inputs, dynamic_shapes=shapes)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def get_inputs_default(
|
|
234
|
+
model: torch.nn.Module,
|
|
235
|
+
config: Optional[Any],
|
|
236
|
+
dummy_max_token_id: int,
|
|
237
|
+
num_key_value_heads: int,
|
|
238
|
+
num_hidden_layers: int,
|
|
239
|
+
pad_token_id: int,
|
|
240
|
+
image_token_index: int,
|
|
241
|
+
head_dim: int,
|
|
242
|
+
width: int,
|
|
243
|
+
height: int,
|
|
244
|
+
num_channels: int,
|
|
245
|
+
batch_size: Optional[int] = 2,
|
|
246
|
+
sequence_length: Optional[int] = 43,
|
|
247
|
+
n_images: Optional[int] = 2,
|
|
248
|
+
max_sequence_length: Optional[int] = 43,
|
|
249
|
+
total_sequence_length: Optional[int] = 43,
|
|
250
|
+
add_second_input: int = 0,
|
|
251
|
+
**kwargs, # unused
|
|
252
|
+
):
|
|
253
|
+
batch_size = 2 if batch_size is None else batch_size
|
|
254
|
+
sequence_length = 43 if sequence_length is None else sequence_length
|
|
255
|
+
n_images = 2 if n_images is None else n_images
|
|
256
|
+
max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
|
|
257
|
+
total_sequence_length = 43 if total_sequence_length is None else total_sequence_length
|
|
258
|
+
|
|
259
|
+
assert batch_size > 0, "batch_size cannot be null"
|
|
260
|
+
assert (
|
|
261
|
+
"cls_cache" not in kwargs
|
|
262
|
+
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
|
|
263
|
+
batch = "batch"
|
|
264
|
+
batch_img = torch.export.Dim("batch_img", min=1, max=1024)
|
|
265
|
+
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
|
|
266
|
+
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
|
|
267
|
+
images = "images" # torch.export.Dim("images", min=1, max=4096)
|
|
268
|
+
|
|
269
|
+
shapes = {
|
|
270
|
+
"input_ids": {0: batch, 1: seq_length},
|
|
271
|
+
"token_type_ids": {0: batch, 1: seq_length},
|
|
272
|
+
"attention_mask": {0: batch, 1: "cache+seq"},
|
|
273
|
+
"position_ids": {0: batch, 1: "cache+seq"},
|
|
274
|
+
"past_key_values": [
|
|
275
|
+
[{0: batch} for _ in range(num_hidden_layers)],
|
|
276
|
+
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
|
|
277
|
+
],
|
|
278
|
+
"pixel_values": (
|
|
279
|
+
{0: batch, 1: images}
|
|
280
|
+
if model.__class__.__name__ == "IdeficsForVisionText2Text"
|
|
281
|
+
else {0: batch_img}
|
|
282
|
+
),
|
|
283
|
+
"image_attention_mask": {0: batch, 1: seq_length, 2: images},
|
|
284
|
+
"image_grid_thw": {0: batch},
|
|
285
|
+
"use_cache": None,
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
|
|
289
|
+
torch.int64
|
|
290
|
+
)
|
|
291
|
+
if total_sequence_length > 0:
|
|
292
|
+
input_ids[0, 0] = image_token_index
|
|
293
|
+
if min(input_ids.shape) > 1:
|
|
294
|
+
input_ids[1, 1] = image_token_index
|
|
295
|
+
# input_ids[input_ids == image_token_index] = pad_token_id
|
|
296
|
+
token_type_ids = torch.zeros_like(input_ids)
|
|
297
|
+
token_type_ids[input_ids == image_token_index] = 1
|
|
298
|
+
image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
|
|
299
|
+
if n_images > 0:
|
|
300
|
+
image_grid_thw[:, 1] = height
|
|
301
|
+
image_grid_thw[:, 2] = width
|
|
302
|
+
image_grid_thw[0, :] //= 2
|
|
303
|
+
image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
|
|
304
|
+
|
|
305
|
+
inputs = dict(
|
|
306
|
+
input_ids=input_ids,
|
|
307
|
+
token_type_ids=token_type_ids,
|
|
308
|
+
attention_mask=torch.cat(
|
|
309
|
+
[
|
|
310
|
+
torch.ones((batch_size, sequence_length), dtype=torch.int64),
|
|
311
|
+
input_ids.ne(pad_token_id).to(torch.int64),
|
|
312
|
+
],
|
|
313
|
+
axis=-1,
|
|
314
|
+
),
|
|
315
|
+
position_ids=torch.arange(0, total_sequence_length)
|
|
316
|
+
.to(torch.int64)
|
|
317
|
+
.expand((batch_size, -1)),
|
|
318
|
+
past_key_values=make_dynamic_cache(
|
|
319
|
+
[
|
|
320
|
+
(
|
|
321
|
+
torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
|
|
322
|
+
torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
|
|
323
|
+
)
|
|
324
|
+
for i in range(num_hidden_layers)
|
|
325
|
+
]
|
|
326
|
+
),
|
|
327
|
+
pixel_values=(
|
|
328
|
+
torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
|
|
329
|
+
if model.__class__.__name__ == "IdeficsForVisionText2Text"
|
|
330
|
+
else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
|
|
331
|
+
),
|
|
332
|
+
image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to(
|
|
163
333
|
torch.int64
|
|
164
334
|
),
|
|
335
|
+
image_grid_thw=image_grid_thw,
|
|
165
336
|
use_cache=True, # Gemma3 does not set this value to true when a cache is provided
|
|
166
337
|
)
|
|
167
|
-
|
|
338
|
+
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
339
|
+
return res
|
|
168
340
|
|
|
169
341
|
|
|
170
342
|
def get_inputs(
|
|
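
`get_inputs_default` keeps the previous behaviour for non-Gemma3 models: per-layer key/value tensors of shape `(batch_size, num_key_value_heads, sequence_length, head_dim)` wrapped by `make_dynamic_cache`. A sketch of the same layout built directly with transformers' `DynamicCache` (an assumption for illustration; the helper's exact return type is not shown in this diff):

    import torch
    from transformers.cache_utils import DynamicCache

    batch_size, num_key_value_heads, sequence_length, head_dim = 2, 8, 43, 64
    num_hidden_layers = 2

    cache = DynamicCache()
    for layer_idx in range(num_hidden_layers):
        key = torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim)
        value = torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim)
        cache.update(key, value, layer_idx)

    print(cache.get_seq_length())  # 43
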
@@ -179,12 +351,12 @@ def get_inputs(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int =
-    sequence_length: int =
-
-
-
-    add_second_input: int =
+    batch_size: Optional[int] = None,
+    sequence_length: Optional[int] = None,
+    n_images: Optional[int] = None,
+    max_sequence_length: Optional[int] = None,
+    total_sequence_length: Optional[int] = None,
+    add_second_input: int = 0,
     **kwargs,  # unused
 ):
     """
@@ -198,13 +370,19 @@ def get_inputs(
     :param image_token_index: image_token_index
     :param batch_size: batch size
     :param sequence_length: sequence length
-    :param
+    :param max_sequence_length: for the cache
+    :param total_sequence_length: for the mask
     :param n_images: number of images
     :param width: width of the image
     :param height: height of the image
     :param num_channels: number of channels
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
+
+    .. note::
+
+        The content of the input_ids and its shape is correlated to the images.
+        The function uses a predefined values. The function raises an exception
+        if dimension are not the expected ones.
     """
     if model.__class__.__name__.startswith("Gemma3"):
         res = _get_inputs_gemma3(
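
The note above is the contract behind both input builders: image-token positions in `input_ids` drive `token_type_ids`, so the ids cannot be fully random once images are involved. A runnable illustration of that coupling (not package code, with arbitrary example values):

    import torch

    image_token_index, vocab_size = 99, 1000  # arbitrary values for the example
    input_ids = torch.randint(1, vocab_size, (2, 7), dtype=torch.int64)
    input_ids[0, 0] = image_token_index
    input_ids[1, 1] = image_token_index

    # same rule as in get_inputs_default: 1 where an image token sits, 0 elsewhere
    token_type_ids = torch.zeros_like(input_ids)
    token_type_ids[input_ids == image_token_index] = 1
    print(token_type_ids)
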
@@ -221,92 +399,32 @@ def get_inputs(
             num_channels=num_channels,
             batch_size=batch_size,
             sequence_length=sequence_length,
-
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
             n_images=n_images,
-            dynamic_rope=dynamic_rope,
             **kwargs,
         )
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            "pixel_values": (
-                {0: batch, 1: images}
-                if model.__class__.__name__ == "IdeficsForVisionText2Text"
-                else {0: batch_img}
-            ),
-            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
-            "image_grid_thw": {0: batch},
-            "use_cache": None,
-        }
-
-        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
+        res = get_inputs_default(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
         )
-        input_ids[0, 0] = image_token_index
-        input_ids[1, 1] = image_token_index
-        # input_ids[input_ids == image_token_index] = pad_token_id
-        token_type_ids = torch.zeros_like(input_ids)
-        token_type_ids[input_ids == image_token_index] = 1
-        image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
-        image_grid_thw[:, 1] = height
-        image_grid_thw[:, 2] = width
-        image_grid_thw[0, :] //= 2
-        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
 
-        inputs = dict(
-            input_ids=input_ids,
-            attention_mask=torch.cat(
-                [
-                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
-                    input_ids.ne(pad_token_id).to(torch.int64),
-                ],
-                axis=-1,
-            ),
-            position_ids=torch.arange(0, sequence_length2)
-            .to(torch.int64)
-            .expand((batch_size, -1)),
-            past_key_values=make_dynamic_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
-                        ),
-                        torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
-            pixel_values=(
-                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
-                if model.__class__.__name__ == "IdeficsForVisionText2Text"
-                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
-            ),
-            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-                torch.int64
-            ),
-            token_type_ids=token_type_ids,
-            image_grid_thw=image_grid_thw,
-            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
-        )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
             add_second_input > 0
@@ -321,11 +439,11 @@ def get_inputs(
             width=width,
             height=height,
             num_channels=num_channels,
-            batch_size=
-            sequence_length=
-
-
-
+            batch_size=3,
+            sequence_length=1,
+            max_sequence_length=1,
+            total_sequence_length=1,
+            n_images=0,
             pad_token_id=pad_token_id,
             image_token_index=image_token_index,
             add_second_input=0,
@@ -368,9 +486,6 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         text_config = False
         check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
-        batch_size=2,
-        sequence_length=43,
-        sequence_length2=43,
         head_dim=(
             16
             if config is None
onnx_diagnostic/tasks/text_generation.py

@@ -269,6 +269,21 @@ def get_inputs(
         add_second_input=0,
         **kwargs,
     )["inputs"]
+    res["inputs_empty_cache"] = get_inputs(
+        model=model,
+        config=config,
+        dummy_max_token_id=dummy_max_token_id,
+        num_hidden_layers=num_hidden_layers,
+        batch_size=batch_size,
+        sequence_length=0,
+        sequence_length2=sequence_length2,
+        dynamic_rope=dynamic_rope,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        cls_cache=cls_cache,
+        add_second_input=0,
+        **kwargs,
+    )["inputs"]
     return res
 
 
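
The added `inputs_empty_cache` entry reuses `get_inputs` with `sequence_length=0`, i.e. the same inputs but with a past key/value cache whose sequence axis is empty. A short sketch of what such a cache layout amounts to (illustration only, not package code):

    import torch

    batch_size, num_key_value_heads, head_dim, num_hidden_layers = 2, 8, 64, 2
    empty_cache_layers = [
        (
            torch.zeros(batch_size, num_key_value_heads, 0, head_dim),
            torch.zeros(batch_size, num_key_value_heads, 0, head_dim),
        )
        for _ in range(num_hidden_layers)
    ]
    print(empty_cache_layers[0][0].shape)  # torch.Size([2, 8, 0, 64])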