onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +66 -8
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +461 -0
- onnx_diagnostic/helpers/cache_helper.py +250 -15
- onnx_diagnostic/helpers/helper.py +146 -10
- onnx_diagnostic/helpers/log_helper.py +404 -315
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/feature_extraction.py +86 -5
- onnx_diagnostic/tasks/image_text_to_text.py +260 -56
- onnx_diagnostic/tasks/mask_generation.py +139 -0
- onnx_diagnostic/tasks/text2text_generation.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
- onnx_diagnostic/torch_models/validate.py +26 -3
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
--- a/onnx_diagnostic/tasks/image_text_to_text.py
+++ b/onnx_diagnostic/tasks/image_text_to_text.py
@@ -1,6 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.cache_helper import make_dynamic_cache
+from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
 from ..helpers.config_helper import update_config, check_hasattr, _pick
 
 __TASK__ = "image-text-to-text"
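The notable change above is the new `make_hybrid_cache` import. Judging from the call sites later in this diff, it mirrors `make_dynamic_cache` and builds a `transformers` `HybridCache` from one `(key, value)` tensor pair per layer; a minimal sketch under that assumption (the exact signature is inferred from the diff, not from documentation):

```python
# Sketch only: make_hybrid_cache's signature is inferred from the call sites below.
import torch
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

num_hidden_layers, num_key_value_heads, max_sequence_length, head_dim = 2, 4, 380, 256
past_key_values = make_hybrid_cache(
    [
        (
            torch.randn(1, num_key_value_heads, max_sequence_length, head_dim),  # keys
            torch.randn(1, num_key_value_heads, max_sequence_length, head_dim),  # values
        )
        for _ in range(num_hidden_layers)
    ]
)
```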
@@ -11,99 +11,284 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, 2)
-    if hasattr(config, "
-        config.
+    if hasattr(config, "mm_tokens_per_image"):
+        config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
+    if hasattr(config, "vision_config"):
+        if hasattr(config.vision_config, "num_hidden_layers"):
+            config.vision_config.num_hidden_layers = min(
+                config.vision_config.num_hidden_layers, 2
+            )
+        if hasattr(config.vision_config, "image_size"):
+            config.vision_config.image_size = min(config.vision_config.image_size, 96)
+        if hasattr(config.vision_config, "intermediate_size"):
+            config.vision_config.intermediate_size = min(
+                config.vision_config.intermediate_size, 1076
+            )
+        if hasattr(config.vision_config, "patch_size"):
+            config.vision_config.patch_size = min(config.vision_config.patch_size, 2)
+        if hasattr(config.vision_config, "hidden_size"):
+            config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
+    if hasattr(config, "text_config"):
+        if hasattr(config.text_config, "intermediate_size"):
+            config.text_config.intermediate_size = min(
+                config.text_config.intermediate_size, 320
+            )
+        if hasattr(config.text_config, "hidden_size"):
+            config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
+        if hasattr(config.text_config, "num_hidden_layers"):
+            config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
+        if hasattr(config.text_config, "layer_types"):
+            config.text_config.layer_types = config.text_config.layer_types[
+                : config.text_config.num_hidden_layers
+            ]
+        if hasattr(config.text_config, "num_attention_heads"):
+            config.text_config.num_attention_heads = min(
+                config.text_config.num_attention_heads, 2
+            )
     update_config(config, kwargs)
     return kwargs
 
 
-def
+def _get_inputs_gemma3(
     model: torch.nn.Module,
     config: Optional[Any],
     dummy_max_token_id: int,
     num_key_value_heads: int,
     num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
     head_dim: int,
     width: int,
     height: int,
     num_channels: int,
     batch_size: int = 2,
-    sequence_length: int =
-    sequence_length2: int =
+    sequence_length: int = 43,
+    sequence_length2: int = 43,
     n_images: int = 2,
     dynamic_rope: bool = False,
-
+    max_sequence_length: int = 380,
     **kwargs,  # unused
 ):
     """
-
+    ::
 
-
-
-
-
-
-
-
-
-
-
-
-
+        dict(input_ids:T7s1x281,
+        pixel_values:T16s1x3x896x896,
+        attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
+        position_ids:T7s1x281,
+        past_key_values:HybridCache(
+            key_cache=#34[T1s1x4x380x256,...],
+            value_cache=#34[T1s1x4x380x256,...]),
+        token_type_ids:T7s1x281,
+        cache_position:T7s281,
+        logits_to_keep:1)
+        dict(input_ids:T7s1x1,
+        pixel_values:None,
+        attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
+        position_ids:T7s1x1,
+        past_key_values:HybridCache(
+            key_cache=#34[T1s1x4x380x256,...],
+            value_cache=#34[T1s1x4x380x256,...]),
+        token_type_ids:T7s1x1,
+        cache_position:T7s1,
+        logits_to_keep:1)
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
-    images = "images"  # torch.export.Dim("images", min=1, max=4096)
+    # cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            0: batch,
-
-        },
-        "position_ids": {
-            0: batch,
-            1: "cache+seq",  # cache_length + seq_length
+            "full_attention": {0: batch, 2: seq_length},
+            "sliding_attention": {0: batch, 2: seq_length},
         },
+        "position_ids": {0: batch, 1: seq_length},
+        "cache_position": {1: seq_length},
         "past_key_values": [
-            [{0: batch
-            [{0: batch
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
-        "pixel_values": {0: batch
-        "
+        "pixel_values": {0: batch},
+        "use_cache": None,
     }
+
+    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+        torch.int64
+    )
+    input_ids[:, 1] = image_token_index
+    # input_ids[input_ids == image_token_index] = pad_token_id
+    token_type_ids = torch.zeros_like(input_ids)
+    token_type_ids[input_ids == image_token_index] = 1
+
     inputs = dict(
-        input_ids=
-
+        input_ids=input_ids,
+        token_type_ids=token_type_ids,
+        attention_mask=dict(
+            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
+            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
         ),
-
-
-
-        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
-        .to(torch.int64)
-        .expand((batch_size, -1)),
-        past_key_values=make_dynamic_cache(
+        cache_position=torch.arange(0, sequence_length).to(torch.int64),
+        position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+        past_key_values=make_hybrid_cache(
             [
                 (
-                    torch.randn(
-
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
                 )
                 for i in range(num_hidden_layers)
             ]
         ),
-        pixel_values=torch.
-            torch.int64
-        ),
+        pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
         image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
             torch.int64
         ),
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
-
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: int = 2,
+    sequence_length: int = 43,
+    sequence_length2: int = 43,
+    n_images: int = 2,
+    dynamic_rope: bool = False,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``image-text-to-text``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: dummy max token id
+    :param pad_token_id: pad_token_id
+    :param image_token_index: image_token_index
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :param n_images: number of images
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :return: dictionary
+    """
+    if model.__class__.__name__.startswith("Gemma3"):
+        res = _get_inputs_gemma3(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            n_images=n_images,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )
+    else:
+        assert (
+            "cls_cache" not in kwargs
+        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+        batch = torch.export.Dim("batch", min=1, max=1024)
+        batch_img = torch.export.Dim("batch_img", min=1, max=1024)
+        seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+        cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+        images = "images"  # torch.export.Dim("images", min=1, max=4096)
+
+        shapes = {
+            "input_ids": {0: batch, 1: seq_length},
+            "token_type_ids": {0: batch, 1: seq_length},
+            "attention_mask": {0: batch, 1: "cache+seq"},
+            "position_ids": {0: batch, 1: "cache+seq"},
+            "past_key_values": [
+                [{0: batch} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            "pixel_values": (
+                {0: batch, 1: images}
+                if model.__class__.__name__ == "IdeficsForVisionText2Text"
+                else {0: batch_img}
+            ),
+            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+            "use_cache": None,
+        }
+
+        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
+            torch.int64
+        )
+        input_ids[0, 0] = image_token_index
+        input_ids[1, 1] = image_token_index
+        # input_ids[input_ids == image_token_index] = pad_token_id
+        token_type_ids = torch.zeros_like(input_ids)
+        token_type_ids[input_ids == image_token_index] = 1
+
+        inputs = dict(
+            input_ids=input_ids,
+            attention_mask=torch.cat(
+                [
+                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
+                    input_ids.ne(pad_token_id).to(torch.int64),
+                ],
+                axis=-1,
+            ),
+            position_ids=torch.arange(0, sequence_length2)
+            .to(torch.int64)
+            .expand((batch_size, -1)),
+            past_key_values=make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, num_key_value_heads, sequence_length, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            pixel_values=(
+                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
+                if model.__class__.__name__ == "IdeficsForVisionText2Text"
+                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
+            ),
+            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+                torch.int64
+            ),
+            token_type_ids=token_type_ids,
+            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+        )
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
             add_second_input > 0
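The core of the rewritten generator is how it plants image tokens and derives `token_type_ids` from them; isolated as a standalone snippet (plain `torch`, sizes picked arbitrarily):

```python
import torch

dummy_max_token_id, image_token_index = 100, 4
batch_size, sequence_length2 = 2, 43

input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(torch.int64)
input_ids[:, 1] = image_token_index  # force one image token into every row
# token_type_ids is 1 exactly where an image token sits, 0 everywhere else
token_type_ids = torch.zeros_like(input_ids)
token_type_ids[input_ids == image_token_index] = 1
assert token_type_ids[:, 1].eq(1).all()
```

The Gemma3 branch also swaps `make_dynamic_cache` for `make_hybrid_cache` and replaces the single 2D attention mask with a `full_attention`/`sliding_attention` pair, matching the `HybridCache` shapes shown in the docstring above.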
@@ -123,6 +308,8 @@ def get_inputs(
             sequence_length2=sequence_length2 + 1,
             n_images=n_images + 1,
             dynamic_rope=dynamic_rope,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
             add_second_input=0,
             **kwargs,
         )["inputs"]
@@ -145,8 +332,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             ("num_key_value_heads", "num_attention_heads"),
             "intermediate_size",
             "hidden_size",
+            "pad_token_id",
         )
-        check_hasattr(config, "vision_config")
+        check_hasattr(config, "vision_config", "image_token_index")
         text_config = True
     else:
         check_hasattr(
@@ -163,19 +351,25 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         check_hasattr(config.vision_config, "image_size", "num_channels")
     kwargs = dict(
         batch_size=2,
-        sequence_length=
-        sequence_length2=
+        sequence_length=43,
+        sequence_length2=43,
         head_dim=(
             16
             if config is None
             else getattr(
                 config,
                 "head_dim",
-                (
-
-                    config.text_config
-
-
+                (
+                    config.text_config.head_dim
+                    if text_config and hasattr(config.text_config, "head_dim")
+                    else (
+                        (config.text_config.hidden_size if text_config else config.hidden_size)
+                        // (
+                            config.text_config.num_attention_heads
+                            if text_config
+                            else config.num_attention_heads
+                        )
+                    )
                 ),
             )
         ),
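The new default branch is the standard identity `head_dim = hidden_size // num_attention_heads`, applied to `config.text_config` when a text sub-config exists. Flattened for readability (this sketch drops the `text_config.head_dim` shortcut and assumes a text sub-config is present):

```python
# e.g. hidden_size=16 and num_attention_heads=2 -> head_dim=8
head_dim = getattr(
    config,
    "head_dim",
    config.text_config.hidden_size // config.text_config.num_attention_heads,
)
```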
@@ -219,5 +413,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         width=224 if config is None else config.vision_config.image_size,
         height=224 if config is None else config.vision_config.image_size,
         num_channels=3 if config is None else config.vision_config.num_channels,
+        pad_token_id=(
+            0
+            if config is None or not hasattr(config, "text_config")
+            else config.text_config.pad_token_id
+        ),
+        image_token_index=(
+            4
+            if config is None or not hasattr(config, "image_token_index")
+            else config.image_token_index
+        ),
     )
     return kwargs, get_inputs
--- /dev/null
+++ b/onnx_diagnostic/tasks/mask_generation.py
@@ -0,0 +1,139 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "mask-generation"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
+        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    output_channels: int,
+    window_size: int,
+    add_second_input: bool = True,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``mask-generation``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param batch_size: batch size
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels in the image
+    :param output_channels: number of output channels
+    :param window_size: size of the window for the vision model
+    :return: dictionary with inputs and dynamic shapes
+
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+
+    # TODO(anyone): input_masks is weirdly failing all the time with mismatch channels
+    # with Conv or embedding_size. I guess maybe the model is too implicit on the
+    # input_masks shape.
+
+    # TODO(titaiwang): modeling code specifically requires the height and width of inputs
+    # should be the same as the config.vision_config.image_size. Does that make sense?
+
+    shapes = {
+        "pixel_values": {0: "batch"},  # 1: num_channels is static
+        "input_points": {0: "batch", 1: "point_batch_size", 2: "nb_points_per_image"},
+        "input_boxes": {0: "batch", 1: "point_batch_size"},
+        # "input_masks": {0: "batch", 2: "height", 3: "width"},
+    }
+    inputs = dict(
+        pixel_values=torch.randn(
+            (batch_size, num_channels, height, width), dtype=torch.float32
+        ).clamp(-1, 1),
+        input_points=torch.randn(
+            (batch_size, 2, 10, 2), dtype=torch.float32
+        ),  # 10 points per image
+        input_boxes=torch.randn((batch_size, 2, 4), dtype=torch.float32),  # 1 box per image
+        # input_masks=torch.randn(
+        #     (batch_size, 1, height, width), dtype=torch.float32
+        # ),  # mask for the image
+    )
+
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            output_channels=output_channels,
+            window_size=window_size,
+            add_second_input=False,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        # generates mask as outputs
+        if hasattr(config, "mask_decoder_config"):
+            check_hasattr(
+                config.mask_decoder_config,
+                "hidden_size",
+                "iou_head_hidden_dim",
+                "iou_head_depth",
+                "num_hidden_layers",
+                "num_multimask_outputs",
+            )
+        if hasattr(config, "prompt_encoder_config"):
+            check_hasattr(
+                config.prompt_encoder_config,
+                "hidden_size",
+                "image_embedding_size",
+                "image_size",
+                "mask_input_channels",
+            )
+        if hasattr(config, "vision_config"):
+            check_hasattr(
+                config.vision_config,
+                "image_size",
+                "hidden_size",
+                "intermediate_size",
+                "num_hidden_layers",
+                "output_channels",
+                "num_channels",
+                "window_size",
+            )
+    kwargs = dict(
+        batch_size=2,
+        width=1024 if config is None else config.vision_config.image_size,
+        height=1024 if config is None else config.vision_config.image_size,
+        num_channels=3 if config is None else config.vision_config.num_channels,
+        output_channels=256 if config is None else config.vision_config.output_channels,
+        window_size=14 if config is None else config.vision_config.window_size,
+    )
+    return kwargs, get_inputs
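The new module follows the contract shared by the other `onnx_diagnostic.tasks` modules: `random_input_kwargs` returns default dimensions plus the `get_inputs` callable. A usage sketch with `config=None`, so the typical SAM-like sizes above apply (`model` is only passed through by this particular generator, hence `None`):

```python
from onnx_diagnostic.tasks.mask_generation import random_input_kwargs

kwargs, fct = random_input_kwargs(None)       # width=height=1024, 3 channels, ...
res = fct(model=None, config=None, **kwargs)
print(res["inputs"]["pixel_values"].shape)    # torch.Size([2, 3, 1024, 1024])
print(res["inputs2"]["pixel_values"].shape)   # second set uses batch_size + 1
```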
--- a/onnx_diagnostic/tasks/text2text_generation.py
+++ b/onnx_diagnostic/tasks/text2text_generation.py
@@ -69,8 +69,8 @@ def get_inputs(
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"
-    cache_length2 = "cache_length_val"
+    cache_length = "cache_length_key"
+    cache_length2 = "cache_length_val"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -1,6 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-import transformers
 from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
@@ -95,9 +94,14 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
         assert cls_cache in (
             "MambaCache",
-
+            MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
     sequence_length = (
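The guarded import exists because recent `transformers` releases moved `MambaCache` out of `transformers.cache_utils` and into the Mamba modeling module; the `except ImportError` branch keeps older versions working. The same two-line guard appears again in `onnx_export_serialization.py` below:

```python
try:  # newer transformers: MambaCache lives next to the Mamba model
    from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:  # older transformers still export it from cache_utils
    from transformers.cache_utils import MambaCache
```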
--- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -16,6 +16,8 @@ def get_function(name: str) -> Tuple[type, Callable]:
     module_name = ".".join(spl[:-1])
     fname = spl[-1]
     mod = importlib.import_module(module_name)
+    if not hasattr(mod, fname):
+        return None, None
     return mod, getattr(mod, fname)
 
 
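With the guard, `get_function` resolves a dotted name to `(module, attribute)` and now degrades to `(None, None)` when the attribute is missing instead of raising `AttributeError`. Reassembled from the context lines (the `split` preamble is not shown in this hunk, so it is an assumption):

```python
import importlib
from typing import Callable, Tuple

def get_function(name: str) -> Tuple[type, Callable]:
    # assumed preamble: "pkg.mod.attr" -> module "pkg.mod", attribute "attr"
    spl = name.split(".")
    module_name = ".".join(spl[:-1])
    fname = spl[-1]
    mod = importlib.import_module(module_name)
    if not hasattr(mod, fname):
        # attribute absent in this version of the library: nothing to patch
        return None, None
    return mod, getattr(mod, fname)
```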
@@ -33,12 +35,16 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
         doc = v.__doc__.lstrip()
         if doc.startswith("manual patch"):
             continue
-        reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+        reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
         fall = reg.findall(doc)
         assert (
             len(fall) == 1
         ), f"Unable to find patching information for {v} in \n{doc}"
         fmod, f = get_function(fall[0])
+        if fmod is None and f is None:
+            # The function does not exist in this version of transformers.
+            # No patch is needed.
+            continue
         to_patch.append({"module": fmod, "function": f, "patch": v})
 
     name = mod.__name__
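Both spellings of the pattern are equivalent: `[[]` and `[\\[]` are each a one-character class matching a literal `[`; the escaped form is merely easier on linters than the `[]]` idiom. What the regex pulls out of a patch docstring (the docstring here is invented for the demo):

```python
import re

reg = re.compile(r"[\[]patch:([a-z_A-Z.]+)[\]]")
doc = "[patch:transformers.cache_utils.DynamicCache] patched for export."
print(reg.findall(doc))  # ['transformers.cache_utils.DynamicCache']
```

Combined with the `(None, None)` guard above, a `[patch:...]` target that no longer exists in the installed `transformers` is now skipped instead of aborting the patching pass.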
--- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -6,12 +6,17 @@ import torch
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
+    HybridCache,
     SlidingWindowCache,
     StaticCache,
 )
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
@@ -161,6 +166,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
+        flatten_hybrid_cache,
+        unflatten_hybrid_cache,
+        flatten_with_keys_hybrid_cache,
         flatten_mamba_cache,
         unflatten_mamba_cache,
         flatten_with_keys_mamba_cache,
@@ -187,6 +195,14 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
+        HybridCache: lambda verbose=verbose: register_class_serialization(
+            HybridCache,
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            verbose=verbose,
+        ),
         MambaCache: lambda verbose=verbose: register_class_serialization(
             MambaCache,
             flatten_mamba_cache,
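`register_class_serialization` ultimately teaches `torch.export` to flatten a cache into tensors and rebuild it afterwards. A self-contained sketch of the underlying mechanism on a toy stand-in class, using `torch.utils._pytree` directly (this shows the idea behind the helper, not the library's implementation):

```python
import torch
import torch.utils._pytree as pytree

class ToyCache:
    """Stand-in for HybridCache: two lists of tensors."""
    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache
        self.value_cache = value_cache

def flatten_toy_cache(cache):
    # children to trace through, plus a context used to rebuild the object
    return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]

def unflatten_toy_cache(values, context):
    return ToyCache(*values)

pytree.register_pytree_node(
    ToyCache, flatten_toy_cache, unflatten_toy_cache,
    serialized_type_name="ToyCache",
)

cache = ToyCache([torch.ones(1, 4, 380, 256)], [torch.zeros(1, 4, 380, 256)])
flat, spec = pytree.tree_flatten(cache)      # two leaf tensors + a spec
rebuilt = pytree.tree_unflatten(flat, spec)  # a ToyCache again
```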
--- a/onnx_diagnostic/torch_export_patches/patch_inputs.py
+++ b/onnx_diagnostic/torch_export_patches/patch_inputs.py
@@ -70,6 +70,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
     :param verbose: verbosity
     :return: (args, kwargs, dynamic shapes)
     """
+    from ..helpers.cache_helper import CacheKeyValue
+
     new_kwargs = {}
     if args:
         assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
@@ -121,7 +123,8 @@
             changes[k] = type(updated_kwargs[k])
             continue
         if isinstance(v, transformers.cache_utils.DynamicCache):
-
+            ca = CacheKeyValue(v)
+            updated_kwargs[k] = [ca.key_cache, ca.value_cache]
             changes[k] = type(v)
             continue
         raise NotImplementedError(