onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +66 -8
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +461 -0
- onnx_diagnostic/helpers/cache_helper.py +250 -15
- onnx_diagnostic/helpers/helper.py +146 -10
- onnx_diagnostic/helpers/log_helper.py +404 -315
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/feature_extraction.py +86 -5
- onnx_diagnostic/tasks/image_text_to_text.py +260 -56
- onnx_diagnostic/tasks/mask_generation.py +139 -0
- onnx_diagnostic/tasks/text2text_generation.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
- onnx_diagnostic/torch_models/validate.py +26 -3
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
|
@@ -1,12 +1,19 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from functools import wraps
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import Callable, List, Optional, Tuple
|
|
5
5
|
import packaging.version as pv
|
|
6
6
|
import torch
|
|
7
7
|
import transformers
|
|
8
8
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
9
|
-
from transformers.cache_utils import StaticCache, Cache
|
|
9
|
+
from transformers.cache_utils import StaticCache, Cache
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from transformers.cache_utils import parse_processor_args # noqa: F401
|
|
13
|
+
|
|
14
|
+
patch_parse_processor_args = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
patch_parse_processor_args = False
|
|
10
17
|
|
|
11
18
|
try:
|
|
12
19
|
import transformers.masking_utils
|
|
@@ -15,10 +22,10 @@ try:
|
|
|
15
22
|
except ImportError:
|
|
16
23
|
patch_masking_utils = False
|
|
17
24
|
|
|
25
|
+
|
|
18
26
|
from ...ext_test_case import has_transformers
|
|
19
27
|
from ...helpers.torch_helper import is_torchdynamo_exporting
|
|
20
28
|
|
|
21
|
-
|
|
22
29
|
if patch_masking_utils:
|
|
23
30
|
# Introduced in 4.52
|
|
24
31
|
from transformers.masking_utils import causal_mask_function, sdpa_mask
|
|
@@ -110,6 +117,46 @@ if patch_masking_utils:
|
|
|
110
117
|
return mask
|
|
111
118
|
|
|
112
119
|
|
|
120
|
+
if patch_parse_processor_args:

    def _signature_params(init: Callable) -> Optional[List[str]]:
        """
        Returns the names of the parameters of a processor ``__init__``
        without the first two ones (presumably ``self`` and the cache/layer
        argument -- TODO confirm against transformers), or None if
        :func:`inspect.signature` cannot introspect the function.
        """
        try:
            return list(inspect.signature(init).parameters)[2:]
        except Exception:
            # Deliberately best-effort: this runs at import time and must never
            # prevent the module from loading. None means "no processor kwargs".
            return None

    def _init_cache_inspect():
        """
        Precomputes the parameter names of every processor registered in
        ``transformers.cache_utils.PROCESSOR_CLASS_MAP``, keyed by the
        processor's ``__init__`` function, so that
        :func:`patched_parse_processor_args` does not need to call
        :mod:`inspect` while the model is being traced.
        """
        # getattr: do not fail at import time if the map is missing.
        processor_map = getattr(transformers.cache_utils, "PROCESSOR_CLASS_MAP", {})
        return {
            processor_class.__init__: _signature_params(processor_class.__init__)
            for processor_class in processor_map.values()
        }

    _cache_inspect = _init_cache_inspect()

    def patched_parse_processor_args(
        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
    ) -> tuple[dict, dict]:
        """[patch:transformers.cache_utils.parse_processor_args]"""
        # Splits ``kwargs`` into (arguments accepted by the processor's
        # ``__init__``, everything else).
        #
        # If not patched...
        # Fails with transformers>=4.54 because function ``parse_processor_args``
        # relies on inspect and the exporter is not very fond of that.
        # torch._dynamo.exc.Unsupported: id() with unsupported args
        # Explanation: Dynamo doesn't know how to trace id()
        # call with args
        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
        # objects from outside the compiled region.
        # Hint: It may be possible to write Dynamo tracing rules for this code.
        #
        # The patch is caching the signature to avoid any call to inspect.
        if processor_class is None:
            return {}, kwargs
        init = processor_class.__init__
        if init not in _cache_inspect:
            # Processor not registered in PROCESSOR_CLASS_MAP (e.g. a user-defined
            # one): compute its signature once and memoize it instead of raising
            # a KeyError. This path still uses inspect, so it is only safe
            # outside of tracing, like the unpatched function.
            _cache_inspect[init] = _signature_params(init)
        params = _cache_inspect[init]
        if params is None:
            # Signature could not be introspected: forward nothing to the processor.
            return {}, kwargs
        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
        return processor_kwargs, remaining_kwargs
|
|
158
|
+
|
|
159
|
+
|
|
113
160
|
def _patch_make_causal_mask(
|
|
114
161
|
input_ids_shape: torch.Size,
|
|
115
162
|
dtype: torch.dtype,
|
|
@@ -192,134 +239,140 @@ class patched_AttentionMaskConverter:
|
|
|
192
239
|
return _patch_make_causal_mask(**kwargs)
|
|
193
240
|
|
|
194
241
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
`transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
|
|
199
|
-
"""
|
|
200
|
-
|
|
201
|
-
_PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
|
|
202
|
-
_PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
|
|
203
|
-
|
|
204
|
-
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
|
205
|
-
"""Returns the sequence length of the cached states.
|
|
206
|
-
A layer index can be optionally passed."""
|
|
207
|
-
# TODO: deprecate this function in favor of `cache_position`
|
|
208
|
-
is_empty_layer = (
|
|
209
|
-
len(self.key_cache) == 0 # no cache in any layer
|
|
210
|
-
or len(self.key_cache)
|
|
211
|
-
<= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
|
|
212
|
-
or self.key_cache[layer_idx].numel() == 0 # the layer has no cache
|
|
213
|
-
)
|
|
214
|
-
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
|
|
215
|
-
return layer_seq_length
|
|
216
|
-
|
|
217
|
-
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
218
|
-
"""Reorders the cache for beam search, given the selected beam indices."""
|
|
219
|
-
for layer_idx in range(len(self.key_cache)):
|
|
220
|
-
if self.key_cache[layer_idx].numel():
|
|
221
|
-
device = self.key_cache[layer_idx].device
|
|
222
|
-
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
|
|
223
|
-
0, beam_idx.to(device)
|
|
224
|
-
)
|
|
225
|
-
if self.value_cache[layer_idx].numel():
|
|
226
|
-
device = self.value_cache[layer_idx].device
|
|
227
|
-
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
|
|
228
|
-
0, beam_idx.to(device)
|
|
229
|
-
)
|
|
242
|
+
if pv.Version(transformers.__version__) < pv.Version("4.51"):
    # The patched class below is only defined for transformers < 4.51.
    from typing import Any, Dict

    from transformers.cache_utils import DynamicCache

    class patched_DynamicCache:
        """
        Applies modifications implemented in PR
        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.

        ``_PATCHES_`` names the methods meant to be patched onto
        ``_PATCHED_CLASS_`` (``transformers.cache_utils.DynamicCache``).
        Empty layers are detected with ``numel()`` rather than ``len()``
        so that the checks remain exportable.
        """

        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache

        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            """Returns the sequence length of the cached states.
            A layer index can be optionally passed.

            Returns 0 when no layer is cached yet, when ``layer_idx`` is beyond
            the cached layers, or when that layer holds an empty placeholder.
            """
            # TODO: deprecate this function in favor of `cache_position`
            is_empty_layer = (
                len(self.key_cache) == 0  # no cache in any layer
                or len(self.key_cache)
                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
            )
            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
            return layer_seq_length

        def reorder_cache(self, beam_idx: torch.LongTensor):
            """Reorders the cache for beam search, given the selected beam indices.

            Each non-empty key/value tensor is re-indexed along dimension 0 with
            ``beam_idx`` (moved to that tensor's device); empty placeholders are
            left untouched.
            """
            for layer_idx in range(len(self.key_cache)):
                if self.key_cache[layer_idx].numel():
                    device = self.key_cache[layer_idx].device
                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
                        0, beam_idx.to(device)
                    )
                if self.value_cache[layer_idx].numel():
                    device = self.value_cache[layer_idx].device
                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
                        0, beam_idx.to(device)
                    )

        def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Updates the cache with the new `key_states`
            and `value_states` for the layer `layer_idx`.
            Parameters:
                key_states (`torch.Tensor`):
                    The new key states to cache.
                value_states (`torch.Tensor`):
                    The new value states to cache.
                layer_idx (`int`):
                    The index of the layer to cache the states for.
                cache_kwargs (`Dict[str, Any]`, `optional`):
                    Additional arguments for the cache subclass.
                    No additional arguments are used in `DynamicCache`.
            Return:
                A tuple containing the updated key and value states.
            """
            # Update the number of seen tokens (the attribute may not exist
            # depending on the transformers version).
            if layer_idx == 0:
                if hasattr(self, "_seen_tokens"):
                    self._seen_tokens += key_states.shape[-2]

            # Update the cache
            if key_states is not None:
                if len(self.key_cache) <= layer_idx:
                    # There may be skipped layers, fill them with empty tensors.
                    # The placeholders are created on the same device and with the
                    # same dtype as the incoming states so the cache stays
                    # homogeneous (previously they always landed on the default
                    # device and the value placeholders used the key dtype).
                    for _ in range(len(self.key_cache), layer_idx):
                        self.key_cache.append(
                            torch.tensor([], dtype=key_states.dtype, device=key_states.device)
                        )
                        self.value_cache.append(
                            torch.tensor(
                                [], dtype=value_states.dtype, device=value_states.device
                            )
                        )
                    self.key_cache.append(key_states)
                    self.value_cache.append(value_states)
                elif not self.key_cache[
                    layer_idx
                ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
                    # fills previously skipped layers; checking for tensor causes errors
                    self.key_cache[layer_idx] = key_states
                    self.value_cache[layer_idx] = value_states
                else:
                    # Append the new tokens along the sequence dimension (-2).
                    self.key_cache[layer_idx] = torch.cat(
                        [self.key_cache[layer_idx], key_states], dim=-2
                    )
                    self.value_cache[layer_idx] = torch.cat(
                        [self.value_cache[layer_idx], value_states], dim=-2
                    )
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        def crop(self, max_length: int):
            """Crop the past key values up to a new `max_length`
            in terms of tokens. `max_length` can also be
            negative to remove `max_length` tokens.
            This is used in assisted decoding and contrastive search.

            Does nothing when the cache is already at most ``max_length`` long.
            Empty placeholder layers are left untouched.
            """
            # In case it is negative
            if max_length < 0:
                max_length = self.get_seq_length() - abs(max_length)

            if self.get_seq_length() <= max_length:
                return

            if hasattr(self, "_seen_tokens"):
                self._seen_tokens = max_length
            for idx in range(len(self.key_cache)):
                if self.key_cache[idx].numel():
                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

        @classmethod
        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
            """This is the opposite of the above `batch_split()` method.
            This will be used by `stack_model_outputs` in
            `generation.utils`

            For every layer index of ``splits[0]``, the non-empty key (resp.
            value) tensors of all splits are concatenated along dimension 0;
            layers that are empty in every split are skipped.
            """
            cache = cls()
            for idx in range(len(splits[0])):
                key_cache = [
                    current.key_cache[idx]
                    for current in splits
                    if current.key_cache[idx].numel()
                ]
                value_cache = [
                    current.value_cache[idx]
                    for current in splits
                    if current.value_cache[idx].numel()
                ]
                if key_cache:
                    layer_keys = torch.cat(key_cache, dim=0)
                    layer_values = torch.cat(value_cache, dim=0)
                    cache.update(layer_keys, layer_values, idx)
            return cache
|
|
323
376
|
|
|
324
377
|
|
|
325
378
|
class patched_GenerationMixin:
|
|
@@ -862,6 +915,91 @@ def patched_dynamic_rope_update(rope_forward):
|
|
|
862
915
|
return wrapper
|
|
863
916
|
|
|
864
917
|
|
|
918
|
+
def common_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Eager attention shared by the Bart and Marian patches.

    Computes ``softmax(query @ key^T * scaling + mask) @ value``. It differs from a
    plain eager attention in one place: a 4D ``attention_mask`` is truncated along
    its last dimension to ``key.shape[-2]`` before being added to the scores, so a
    mask longer than the key sequence is cut to match it.

    :param module: owning module; only ``module.training`` is read (dropout switch)
    :param query: query tensor, multiplied with ``key.transpose(2, 3)``
    :param key: key tensor; ``key.shape[-2]`` is the number of attended positions
    :param value: value tensor, multiplied with the attention weights
    :param attention_mask: optional additive mask (None to disable)
    :param scaling: multiplier applied to the scores,
        defaults to ``query.size(-1) ** -0.5``
    :param dropout: dropout probability applied to the attention weights
        (only active when ``module.training`` is True)
    :param head_mask: optional per-head multiplier, reshaped to ``(1, -1, 1, 1)``
    :param kwargs: accepted for signature compatibility, ignored
    :return: ``(attn_output, attn_weights)``; ``attn_output`` has dimensions 1
        and 2 swapped and is made contiguous
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Added compared to the original implementation: trim a 4D mask to the
        # current key length so that the addition below is well defined.
        if attention_mask.ndim == 4:
            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = torch.nn.functional.dropout(
        attn_weights, p=dropout, training=module.training
    )
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def patched_model_bart_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
    # Bart uses the shared eager attention unchanged; this wrapper only exists so
    # that the docstring tag above names Bart's function as the patch target.
    # None of the named keywords can appear in ``kwargs`` (they are bound to the
    # parameters above), so merging them into one mapping cannot collide.
    forwarded = dict(
        attention_mask=attention_mask,
        scaling=scaling,
        dropout=dropout,
        head_mask=head_mask,
    )
    forwarded.update(kwargs)
    return common_eager_attention_forward(module, query, key, value, **forwarded)
|
|
976
|
+
|
|
977
|
+
|
|
978
|
+
def patched_modeling_marian_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
    # Marian uses the shared eager attention unchanged; this wrapper only exists so
    # that the docstring tag above names Marian's function as the patch target.
    # None of the named keywords can appear in ``kwargs`` (they are bound to the
    # parameters above), so merging them into one mapping cannot collide.
    forwarded = dict(
        attention_mask=attention_mask,
        scaling=scaling,
        dropout=dropout,
        head_mask=head_mask,
    )
    forwarded.update(kwargs)
    return common_eager_attention_forward(module, query, key, value, **forwarded)
|
|
1001
|
+
|
|
1002
|
+
|
|
865
1003
|
class common_RotaryEmbedding(torch.nn.Module):
|
|
866
1004
|
@torch.no_grad()
|
|
867
1005
|
@patched_dynamic_rope_update
|
|
@@ -1093,4 +1231,135 @@ class patched_IdeficsAttention(torch.nn.Module):
|
|
|
1093
1231
|
if output_attentions:
|
|
1094
1232
|
attn_weights = None
|
|
1095
1233
|
|
|
1096
|
-
|
|
1234
|
+
if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
|
|
1235
|
+
return attn_output, attn_weights, past_key_value
|
|
1236
|
+
return attn_output, attn_weights
|
|
1237
|
+
|
|
1238
|
+
|
|
1239
|
+
class patched_SamMaskDecoder(torch.nn.Module):
    """
    Exportable replacement for ``SamMaskDecoder.forward``.

    The data-dependent Python ``if`` on the sparse prompt embeddings is replaced by
    :func:`torch.cond`. The remaining computation follows the original decoder:
    token concatenation, two-way transformer, mask upscaling, hypernetwork MLPs
    and IoU prediction.
    """

    # Methods meant to be patched onto ``_PATCHED_CLASS_``.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
    ) -> tuple[Optional[torch.Tensor], ...]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                the embeddings from the image encoder, unpacked as
                ``(batch_size, num_channels, height, width)``
            image_positional_embeddings (`torch.Tensor`):
                positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes; dimension 1 is the
                number of points per image, and they are concatenated to the
                output tokens along dimension 2
            dense_prompt_embeddings (`torch.Tensor`):
                the embeddings of the mask inputs, added to image_embeddings
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
            output_attentions (bool, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            attention_similarity (`torch.Tensor`, *optional*):
                forwarded unchanged to ``self.transformer``
            target_embedding (`torch.Tensor`, *optional*):
                forwarded unchanged to ``self.transformer``

        Returns:
            ``(masks, iou_pred)`` when ``self.transformer`` returns exactly two
            values, otherwise ``(masks, iou_pred, attentions)`` where the third
            element is the transformer's third output if ``output_attentions``
            is truthy and ``None`` otherwise.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens (IoU token first, then the mask tokens),
        # repeated for every image and every point.
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings.
        # torch.any is used as the predicate instead of the Python-level test
        # sparse_prompt_embeddings.sum().item() != 0, which is data-dependent control flow.
        # NOTE(review): the two predicates differ when a non-zero tensor sums to 0
        # (torch.any is then True). Also, the branches return tensors of different
        # sizes along dim 2; this presumably relies on torch.cond support for
        # branch outputs with different shapes -- confirm for the targeted torch version.
        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)

        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
            # clone(): the branch should not return one of its inputs as-is
            # (aliasing) -- presumably a torch.cond requirement; confirm.
            return output_tokens.clone()

        tokens = torch.cond(
            torch.any(sparse_prompt_embeddings != 0),
            sparse_prompt_embeddings_is_not_empty,
            sparse_prompt_embeddings_is_empty,
            [output_tokens, sparse_prompt_embeddings],
        )

        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-point
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
            point_batch_size, 0
        )

        # Run the transformer, image_positional_embedding are consumed
        # torch._check asserts every dimension of point_embeddings is non-zero;
        # presumably this gives the exporter facts about sizes it cannot infer
        # from the torch.cond output -- TODO confirm.
        torch._check(point_embeddings.shape[0] != 0)
        torch._check(point_embeddings.shape[1] != 0)
        torch._check(point_embeddings.shape[2] != 0)
        torch._check(point_embeddings.shape[3] != 0)
        embeddings_attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )
        # The transformer returns at least (point_embedding, image_embeddings);
        # a third element, when present, holds the attentions.
        point_embedding, image_embeddings = embeddings_attentions[:2]
        # Token 0 along dim 2 is the IoU token, the next num_mask_tokens are the
        # mask tokens (same order as in output_tokens above).
        iou_token_out = torch.select(point_embedding, dim=2, index=0)
        mask_tokens_out = torch.narrow(
            point_embedding, dim=2, start=1, length=self.num_mask_tokens
        )

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        # One hypernetwork MLP per mask token.
        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        # num_channels/height/width are rebound to the upscaled sizes here.
        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.reshape(
            batch_size, point_batch_size, num_channels, height * width
        )
        masks = (hyper_in @ upscaled_embedding).reshape(
            batch_size, point_batch_size, -1, height, width
        )

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        # Select the correct mask or masks for output:
        # every mask but the first one, or only the first one.
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if len(embeddings_attentions) == 2:
            # transformers==4.54: the transformer no longer returns attentions,
            # so only (masks, iou_pred) is returned.
            return outputs

        if output_attentions and len(embeddings_attentions) > 2:
            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
        else:
            outputs = outputs + (None,)  # noqa: RUF005
        return outputs
|