onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +56 -3
- onnx_diagnostic/export/dynamic_shapes.py +24 -10
- onnx_diagnostic/export/shape_helper.py +6 -2
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +6 -6
- onnx_diagnostic/helpers/cache_helper.py +326 -18
- onnx_diagnostic/helpers/config_helper.py +10 -0
- onnx_diagnostic/helpers/helper.py +152 -11
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
- onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +7 -3
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +6 -2
- onnx_diagnostic/tasks/image_text_to_text.py +289 -62
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
- onnx_diagnostic/tasks/object_detection.py +6 -2
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +7 -2
- onnx_diagnostic/tasks/text2text_generation.py +7 -2
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +14 -16
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
- onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/validate.py +1 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,12 +1,20 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
-from typing import
+from typing import Callable, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.cache_utils import StaticCache, Cache
+from transformers.cache_utils import StaticCache, Cache
+
+try:
+    from transformers.cache_utils import parse_processor_args  # noqa: F401
+
+    patch_parse_processor_args = True
+except ImportError:
+    patch_parse_processor_args = False

 try:
     import transformers.masking_utils
@@ -15,10 +23,18 @@ try:
 except ImportError:
     patch_masking_utils = False

+
+try:
+    # transformers>= 4.55.1
+    from transformers.cache_utils import DynamicLayer
+
+    patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
+except ImportError:
+    patch_DynamicLayer = False
+
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting

-
 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import causal_mask_function, sdpa_mask
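The two new probes above follow one pattern: at import time, check whether the installed transformers exposes the object a patch targets (an ImportError probe for parse_processor_args, a hasattr probe for DynamicLayer.lazy_initialization), and record the answer in a module-level flag that later guards whether the patch is defined at all. A condensed, self-contained sketch of that pattern (the comments are illustrative, the names mirror the diff):

    # Sketch only: compute capability flags once at import time.
    try:
        from transformers.cache_utils import parse_processor_args  # noqa: F401

        patch_parse_processor_args = True   # the helper exists, so it can be patched
    except ImportError:
        patch_parse_processor_args = False  # older transformers: nothing to patch

    try:
        from transformers.cache_utils import DynamicLayer

        # only meaningful when the method being overridden exists (transformers >= 4.55.1)
        patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
    except ImportError:
        patch_DynamicLayer = False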
@@ -110,6 +126,60 @@ if patch_masking_utils:
         return mask


+if patch_parse_processor_args:
+
+    def _init_cache_inspect():
+        res = {}
+        for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
+            try:
+                params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+                res[processor_class.__init__] = params
+            except Exception:
+                res[processor_class.__init__] = None
+        return res
+
+    _cache_inspect = _init_cache_inspect()
+
+    def patched_parse_processor_args(
+        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
+    ) -> tuple[dict, dict]:
+        """[patch:transformers.cache_utils.parse_processor_args]"""
+        # If not patched...
+        # Fails with transformers>=4.54 because function ``parse_processor_args``
+        # relies in inspect and the exporter is not very fond of that.
+        # torch._dynamo.exc.Unsupported: id() with unsupported args
+        # Explanation: Dynamo doesn't know how to trace id()
+        # call with args
+        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+        # objects from outside the compiled region.
+        # Hint: It may be possible to write Dynamo tracing rules for this code.
+        #
+        # The patch is caching the signature to avoid any call to inspect.
+        if processor_class is None:
+            return {}, kwargs
+        params = _cache_inspect[processor_class.__init__]
+        if params is None:
+            return {}, kwargs
+        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+        return processor_kwargs, remaining_kwargs
+
+
+if patch_DynamicLayer:
+
+    class patched_DynamicLayer:
+        _PATCHES_ = ["lazy_initialization"]
+        _PATCHED_CLASS_ = DynamicLayer
+
+        def lazy_initialization(self, key_states: torch.Tensor):
+            self.dtype, self.device = key_states.dtype, key_states.device
+            new_shape = list(key_states.shape)
+            new_shape[-2] = 0
+            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+
+
 def _patch_make_causal_mask(
     input_ids_shape: torch.Size,
     dtype: torch.dtype,
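The comment block inside patched_parse_processor_args explains the trick: the original transformers helper calls inspect inside the traced region and Dynamo rejects it, so the patch resolves each processor's __init__ signature once at import time and only does dictionary lookups afterwards. A minimal standalone sketch of that idea, using a hypothetical _DemoProcessor class rather than anything from transformers:

    import inspect
    from typing import Optional

    class _DemoProcessor:
        """Hypothetical stand-in for a cache processor class."""
        def __init__(self, cache, window_size=8, dtype=None):
            self.window_size, self.dtype = window_size, dtype

    # inspect.signature runs once here, never inside the exported/traced code
    _SIG_CACHE = {
        _DemoProcessor.__init__: list(
            inspect.signature(_DemoProcessor.__init__).parameters
        )[2:]  # drop `self` and the first positional argument, as the patch does
    }

    def split_kwargs(processor_class: Optional[type], kwargs: dict) -> tuple[dict, dict]:
        if processor_class is None:
            return {}, kwargs
        params = _SIG_CACHE.get(processor_class.__init__)
        if params is None:
            return {}, kwargs
        picked = {k: kwargs[k] for k in params if k in kwargs}
        rest = {k: v for k, v in kwargs.items() if k not in picked}
        return picked, rest

    print(split_kwargs(_DemoProcessor, {"window_size": 4, "max_cache_len": 128}))
    # -> ({'window_size': 4}, {'max_cache_len': 128})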
@@ -192,136 +262,148 @@ class patched_AttentionMaskConverter:
         return _patch_make_causal_mask(**kwargs)


-
-
-
-    `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
-    """
-
-    _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
-    _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states.
-        A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        is_empty_layer = (
-            len(self.key_cache) == 0  # no cache in any layer
-            or len(self.key_cache)
-            <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-            or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
-        )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
-        return layer_seq_length
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx].numel():
-                device = self.key_cache[layer_idx].device
-                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
-                    0, beam_idx.to(device)
-                )
-            if self.value_cache[layer_idx].numel():
-                device = self.value_cache[layer_idx].device
-                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
-                    0, beam_idx.to(device)
-                )
+if pv.Version(transformers.__version__) < pv.Version("4.51"):
+    from typing import Any, Dict
+    from transformers.cache_utils import DynamicCache

-
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    class patched_DynamicCache:
         """
-
-
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass.
-                No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
+        Applies modifications implemented in PR
+        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
         """
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            if hasattr(self, "_seen_tokens"):
-                self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
-                    self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif not self.key_cache[
-                layer_idx
-            ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
-                # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat(
-                    [self.key_cache[layer_idx], key_states], dim=-2
-                )
-                self.value_cache[layer_idx] = torch.cat(
-                    [self.value_cache[layer_idx], value_states], dim=-2
-                )
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]

-… (40 further removed lines, old lines 285-324, are blank in the source diff view and are not reproduced here)
+        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+            """Returns the sequence length of the cached states.
+            A layer index can be optionally passed."""
+            # TODO: deprecate this function in favor of `cache_position`
+            is_empty_layer = (
+                len(self.key_cache) == 0  # no cache in any layer
+                or len(self.key_cache)
+                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+            )
+            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+            return layer_seq_length
+
+        def reorder_cache(self, beam_idx: torch.LongTensor):
+            """Reorders the cache for beam search, given the selected beam indices."""
+            for layer_idx in range(len(self.key_cache)):
+                if self.key_cache[layer_idx].numel():
+                    device = self.key_cache[layer_idx].device
+                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+                if self.value_cache[layer_idx].numel():
+                    device = self.value_cache[layer_idx].device
+                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+
+        def update(
+            self,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            layer_idx: int,
+            cache_kwargs: Optional[Dict[str, Any]] = None,
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Updates the cache with the new `key_states`
+            and `value_states` for the layer `layer_idx`.
+            Parameters:
+                key_states (`torch.Tensor`):
+                    The new key states to cache.
+                value_states (`torch.Tensor`):
+                    The new value states to cache.
+                layer_idx (`int`):
+                    The index of the layer to cache the states for.
+                cache_kwargs (`Dict[str, Any]`, `optional`):
+                    Additional arguments for the cache subclass.
+                    No additional arguments are used in `DynamicCache`.
+            Return:
+                A tuple containing the updated key and value states.
+            """
+            # Update the number of seen tokens
+            if layer_idx == 0:
+                if hasattr(self, "_seen_tokens"):
+                    self._seen_tokens += key_states.shape[-2]
+
+            # Update the cache
+            if key_states is not None:
+                if len(self.key_cache) <= layer_idx:
+                    # There may be skipped layers, fill them with empty lists
+                    for _ in range(len(self.key_cache), layer_idx):
+                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                    self.key_cache.append(key_states)
+                    self.value_cache.append(value_states)
+                elif not self.key_cache[
+                    layer_idx
+                ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
+                    # fills previously skipped layers; checking for tensor causes errors
+                    self.key_cache[layer_idx] = key_states
+                    self.value_cache[layer_idx] = value_states
+                else:
+                    torch._check(
+                        len(self.key_cache[layer_idx].shape) == len(key_states.shape),
+                        lambda: (
+                            f"Rank mismatch len(self.key_cache[layer_idx].shape)="
+                            f"{len(self.key_cache[layer_idx].shape)}, "
+                            f"len(key_states.shape)={len(key_states.shape)}"
+                        ),
+                    )
+                    self.key_cache[layer_idx] = torch.cat(
+                        [self.key_cache[layer_idx], key_states], dim=-2
+                    )
+                    self.value_cache[layer_idx] = torch.cat(
+                        [self.value_cache[layer_idx], value_states], dim=-2
+                    )
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        def crop(self, max_length: int):
+            """Crop the past key values up to a new `max_length`
+            in terms of tokens. `max_length` can also be
+            negative to remove `max_length` tokens.
+            This is used in assisted decoding and contrastive search.
+            """
+            # In case it is negative
+            if max_length < 0:
+                max_length = self.get_seq_length() - abs(max_length)
+
+            if self.get_seq_length() <= max_length:
+                return
+
+            if hasattr(self, "_seen_tokens"):
+                self._seen_tokens = max_length
+            for idx in range(len(self.key_cache)):
+                if self.key_cache[idx].numel():
+                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+        @classmethod
+        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+            """This is the opposite of the above `batch_split()` method.
+            This will be used by `stack_model_outputs` in
+            `generation.utils`"""
+            cache = cls()
+            for idx in range(len(splits[0])):
+                key_cache = [
+                    current.key_cache[idx]
+                    for current in splits
+                    if current.key_cache[idx].numel()
+                ]
+                value_cache = [
+                    current.value_cache[idx]
+                    for current in splits
+                    if current.value_cache[idx].numel()
+                ]
+                if key_cache != []:
+                    layer_keys = torch.cat(key_cache, dim=0)
+                    layer_values = torch.cat(value_cache, dim=0)
+                    cache.update(layer_keys, layer_values, idx)
+            return cache


 class patched_GenerationMixin:
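Every patch class in this file, patched_DynamicCache included, declares _PATCHES_ (the method names to replace) and _PATCHED_CLASS_ (the transformers class that receives them). The machinery that actually applies these declarations lives in onnx_export_errors.py and onnx_export_serialization.py (both also touched in this release); the snippet below is only a hypothetical illustration of how such a declaration could be applied and reverted, not the package's real implementation:

    # Hypothetical helpers, written for illustration only.
    def apply_patch(patch_cls: type) -> dict:
        target = patch_cls._PATCHED_CLASS_
        saved = {}
        for name in patch_cls._PATCHES_:
            saved[name] = getattr(target, name, None)        # remember the original
            setattr(target, name, getattr(patch_cls, name))  # install the patched method
        return saved

    def undo_patch(patch_cls: type, saved: dict) -> None:
        target = patch_cls._PATCHED_CLASS_
        for name, original in saved.items():
            if original is None:
                delattr(target, name)
            else:
                setattr(target, name, original)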
@@ -1183,3 +1265,220 @@ class patched_IdeficsAttention(torch.nn.Module):
         if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
             return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
+
+
+class patched_SamMaskDecoder(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_positional_embeddings: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        output_attentions: Optional[bool] = None,
+        attention_similarity: Optional[torch.Tensor] = None,
+        target_embedding: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Args:
+            image_embeddings (`torch.Tensor`):
+                the embeddings from the image encoder
+            image_positional_embedding (`torch.Tensor`):
+                positional encoding with the shape of image_embeddings
+            sparse_prompt_embeddings (`torch.Tensor`):
+                The embeddings of the points and boxes
+            dense_prompt_embeddings (`torch.Tensor`):
+                the embeddings of the mask inputs
+            multimask_output (bool):
+                Whether to return multiple masks or a single mask.
+            output_attentions (bool, *optional*):
+                Whether or not to return the attentions tensors of all attention layers.
+        """
+        batch_size, num_channels, height, width = image_embeddings.shape
+        point_batch_size = sparse_prompt_embeddings.shape[1]
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
+        # torch.any is needed to avoid data-dependent control flow
+        # with sparse_prompt_embeddings.sum().item() != 0
+        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
+            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+
+        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
+            return output_tokens.clone()
+
+        tokens = torch.cond(
+            torch.any(sparse_prompt_embeddings != 0),
+            sparse_prompt_embeddings_is_not_empty,
+            sparse_prompt_embeddings_is_empty,
+            [output_tokens, sparse_prompt_embeddings],
+        )
+
+        point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+        # Expand per-image data in batch direction to be per-point
+        image_embeddings = image_embeddings + dense_prompt_embeddings
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
+            point_batch_size, 0
+        )
+
+        # Run the transformer, image_positional_embedding are consumed
+        torch._check(point_embeddings.shape[0] != 0)
+        torch._check(point_embeddings.shape[1] != 0)
+        torch._check(point_embeddings.shape[2] != 0)
+        torch._check(point_embeddings.shape[3] != 0)
+        embeddings_attentions = self.transformer(
+            point_embeddings=point_embeddings,
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+        point_embedding, image_embeddings = embeddings_attentions[:2]
+        iou_token_out = torch.select(point_embedding, dim=2, index=0)
+        mask_tokens_out = torch.narrow(
+            point_embedding, dim=2, start=1, length=self.num_mask_tokens
+        )
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        image_embeddings = image_embeddings.transpose(2, 3).reshape(
+            batch_size * point_batch_size, num_channels, height, width
+        )
+
+        upscaled_embedding = self.upscale_conv1(image_embeddings)
+        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
+        hyper_in_list = []
+        for i in range(self.num_mask_tokens):
+            current_mlp = self.output_hypernetworks_mlps[i]
+            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+        hyper_in = torch.stack(hyper_in_list, dim=2)
+
+        _, num_channels, height, width = upscaled_embedding.shape
+        upscaled_embedding = upscaled_embedding.reshape(
+            batch_size, point_batch_size, num_channels, height * width
+        )
+        masks = (hyper_in @ upscaled_embedding).reshape(
+            batch_size, point_batch_size, -1, height, width
+        )
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, :, mask_slice, :, :]
+        iou_pred = iou_pred[:, :, mask_slice]
+
+        outputs = (masks, iou_pred)
+
+        if len(embeddings_attentions) == 2:
+            # transformers==4.54
+            return outputs
+
+        if output_attentions and len(embeddings_attentions) > 2:
+            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
+        else:
+            outputs = outputs + (None,)  # noqa: RUF005
+        return outputs
+
+
+def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
+    """
+    Rewrites the loop in:
+
+    .. code-block:: python
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
+        )
+        for i in range(1, len(seq)):
+            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+    """
+    r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
+    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
+    less = less0.sum(axis=-1, keepdim=True) + 1
+    sq = less * less.T
+    look = (
+        torch.max(seq.min() == 0, less != less.max())
+        * torch.max(seq.max() == mask.shape[-1], less != less.min())
+        * less
+    )
+    filt = (sq != look**2).to(mask.dtype)
+    return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                " computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                " In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #                    cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
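patched_SamMaskDecoder.forward replaces a data-dependent branch (the original sparse_prompt_embeddings.sum().item() != 0 check) with torch.cond, so torch.export can trace both branches without reading a concrete tensor value. A minimal, self-contained sketch of that pattern with a toy module and toy shapes, not the SAM decoder itself; both branch functions return tensors of the same shape and dtype, which torch.cond expects:

    import torch

    class Gate(torch.nn.Module):
        def forward(self, tokens: torch.Tensor, prompts: torch.Tensor) -> torch.Tensor:
            def prompts_are_not_empty(tokens, prompts):
                return tokens + prompts

            def prompts_are_empty(tokens, prompts):
                return tokens.clone()

            return torch.cond(
                torch.any(prompts != 0),   # tensor predicate, no .item()
                prompts_are_not_empty,
                prompts_are_empty,
                [tokens, prompts],
            )

    ep = torch.export.export(Gate(), (torch.ones(2, 3), torch.zeros(2, 3)))
    print(ep.module()(torch.ones(2, 3), torch.zeros(2, 3)))  # takes the "empty" branch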