onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,12 +6,17 @@ import torch
6
6
  import transformers
7
7
  from transformers.cache_utils import (
8
8
  DynamicCache,
9
- MambaCache,
10
9
  EncoderDecoderCache,
10
+ HybridCache,
11
11
  SlidingWindowCache,
12
12
  StaticCache,
13
13
  )
14
14
 
15
+ try:
16
+ from transformers.models.mamba.modeling_mamba import MambaCache
17
+ except ImportError:
18
+ from transformers.cache_utils import MambaCache
19
+
15
20
  from ..helpers import string_type
16
21
  from .serialization import _lower_name_with_
17
22
 
@@ -161,6 +166,9 @@ def serialization_functions(
161
166
  flatten_dynamic_cache,
162
167
  unflatten_dynamic_cache,
163
168
  flatten_with_keys_dynamic_cache,
169
+ flatten_hybrid_cache,
170
+ unflatten_hybrid_cache,
171
+ flatten_with_keys_hybrid_cache,
164
172
  flatten_mamba_cache,
165
173
  unflatten_mamba_cache,
166
174
  flatten_with_keys_mamba_cache,
@@ -187,6 +195,14 @@ def serialization_functions(
187
195
  # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
188
196
  verbose=verbose,
189
197
  ),
198
+ HybridCache: lambda verbose=verbose: register_class_serialization(
199
+ HybridCache,
200
+ flatten_hybrid_cache,
201
+ unflatten_hybrid_cache,
202
+ flatten_with_keys_hybrid_cache,
203
+ # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
204
+ verbose=verbose,
205
+ ),
190
206
  MambaCache: lambda verbose=verbose: register_class_serialization(
191
207
  MambaCache,
192
208
  flatten_mamba_cache,
@@ -70,6 +70,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
70
70
  :param verbose: verbosity
71
71
  :return: (args, kwargs, dynamic shapes)
72
72
  """
73
+ from ..helpers.cache_helper import CacheKeyValue
74
+
73
75
  new_kwargs = {}
74
76
  if args:
75
77
  assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
@@ -121,7 +123,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
121
123
  changes[k] = type(updated_kwargs[k])
122
124
  continue
123
125
  if isinstance(v, transformers.cache_utils.DynamicCache):
124
- updated_kwargs[k] = [v.key_cache, v.value_cache]
126
+ ca = CacheKeyValue(v)
127
+ updated_kwargs[k] = [ca.key_cache, ca.value_cache]
125
128
  changes[k] = type(v)
126
129
  continue
127
130
  raise NotImplementedError(
@@ -1,12 +1,19 @@
1
1
  import inspect
2
2
  from dataclasses import dataclass
3
3
  from functools import wraps
4
- from typing import Any, Callable, Dict, List, Optional, Tuple
4
+ from typing import Callable, List, Optional, Tuple
5
5
  import packaging.version as pv
6
6
  import torch
7
7
  import transformers
8
8
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
9
- from transformers.cache_utils import StaticCache, Cache, DynamicCache
9
+ from transformers.cache_utils import StaticCache, Cache
10
+
11
+ try:
12
+ from transformers.cache_utils import parse_processor_args # noqa: F401
13
+
14
+ patch_parse_processor_args = True
15
+ except ImportError:
16
+ patch_parse_processor_args = False
10
17
 
11
18
  try:
12
19
  import transformers.masking_utils
@@ -15,10 +22,10 @@ try:
15
22
  except ImportError:
16
23
  patch_masking_utils = False
17
24
 
25
+
18
26
  from ...ext_test_case import has_transformers
19
27
  from ...helpers.torch_helper import is_torchdynamo_exporting
20
28
 
21
-
22
29
  if patch_masking_utils:
23
30
  # Introduced in 4.52
24
31
  from transformers.masking_utils import causal_mask_function, sdpa_mask
@@ -110,6 +117,46 @@ if patch_masking_utils:
110
117
  return mask
111
118
 
112
119
 
120
+ if patch_parse_processor_args:
121
+
122
+ def _init_cache_inspect():
123
+ res = {}
124
+ for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
125
+ try:
126
+ params = list(inspect.signature(processor_class.__init__).parameters)[2:]
127
+ res[processor_class.__init__] = params
128
+ except Exception:
129
+ res[processor_class.__init__] = None
130
+ return res
131
+
132
+ _cache_inspect = _init_cache_inspect()
133
+
134
+ def patched_parse_processor_args(
135
+ processor_class: Optional[type["CacheProcessor"]], kwargs: dict # noqa: F821
136
+ ) -> tuple[dict, dict]:
137
+ """[patch:transformers.cache_utils.parse_processor_args]"""
138
+ # If not patched...
139
+ # Fails with transformers>=4.54 because function ``parse_processor_args``
140
+ # relies on inspect and the exporter is not very fond of that.
141
+ # torch._dynamo.exc.Unsupported: id() with unsupported args
142
+ # Explanation: Dynamo doesn't know how to trace id()
143
+ # call with args
144
+ # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
145
+ # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
146
+ # objects from outside the compiled region.
147
+ # Hint: It may be possible to write Dynamo tracing rules for this code.
148
+ #
149
+ # The patch is caching the signature to avoid any call to inspect.
150
+ if processor_class is None:
151
+ return {}, kwargs
152
+ params = _cache_inspect[processor_class.__init__]
153
+ if params is None:
154
+ return {}, kwargs
155
+ processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
156
+ remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
157
+ return processor_kwargs, remaining_kwargs
158
+
159
+
113
160
  def _patch_make_causal_mask(
114
161
  input_ids_shape: torch.Size,
115
162
  dtype: torch.dtype,
@@ -192,136 +239,140 @@ class patched_AttentionMaskConverter:
192
239
  return _patch_make_causal_mask(**kwargs)
193
240
 
194
241
 
195
- class patched_DynamicCache:
196
- """
197
- Applies modifications implemented in PR
198
- `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
199
- """
200
-
201
- _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
202
- _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
203
-
204
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
205
- """Returns the sequence length of the cached states.
206
- A layer index can be optionally passed."""
207
- # TODO: deprecate this function in favor of `cache_position`
208
- is_empty_layer = (
209
- len(self.key_cache) == 0 # no cache in any layer
210
- or len(self.key_cache)
211
- <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
212
- or self.key_cache[layer_idx].numel() == 0 # the layer has no cache
213
- )
214
- layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
215
- return layer_seq_length
216
-
217
- def reorder_cache(self, beam_idx: torch.LongTensor):
218
- """Reorders the cache for beam search, given the selected beam indices."""
219
- for layer_idx in range(len(self.key_cache)):
220
- if self.key_cache[layer_idx].numel():
221
- device = self.key_cache[layer_idx].device
222
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
223
- 0, beam_idx.to(device)
224
- )
225
- if self.value_cache[layer_idx].numel():
226
- device = self.value_cache[layer_idx].device
227
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
228
- 0, beam_idx.to(device)
229
- )
242
+ if pv.Version(transformers.__version__) < pv.Version("4.51"):
243
+ from typing import Any, Dict
244
+ from transformers.cache_utils import DynamicCache
230
245
 
231
- def update(
232
- self,
233
- key_states: torch.Tensor,
234
- value_states: torch.Tensor,
235
- layer_idx: int,
236
- cache_kwargs: Optional[Dict[str, Any]] = None,
237
- ) -> Tuple[torch.Tensor, torch.Tensor]:
246
+ class patched_DynamicCache:
238
247
  """
239
- Updates the cache with the new `key_states`
240
- and `value_states` for the layer `layer_idx`.
241
-
242
- Parameters:
243
- key_states (`torch.Tensor`):
244
- The new key states to cache.
245
- value_states (`torch.Tensor`):
246
- The new value states to cache.
247
- layer_idx (`int`):
248
- The index of the layer to cache the states for.
249
- cache_kwargs (`Dict[str, Any]`, `optional`):
250
- Additional arguments for the cache subclass.
251
- No additional arguments are used in `DynamicCache`.
252
-
253
- Return:
254
- A tuple containing the updated key and value states.
248
+ Applies modifications implemented in PR
249
+ `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
255
250
  """
256
- # Update the number of seen tokens
257
- if layer_idx == 0:
258
- if hasattr(self, "_seen_tokens"):
259
- self._seen_tokens += key_states.shape[-2]
260
-
261
- # Update the cache
262
- if key_states is not None:
263
- if len(self.key_cache) <= layer_idx:
264
- # There may be skipped layers, fill them with empty lists
265
- for _ in range(len(self.key_cache), layer_idx):
266
- self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
267
- self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
268
- self.key_cache.append(key_states)
269
- self.value_cache.append(value_states)
270
- elif not self.key_cache[
271
- layer_idx
272
- ].numel(): # prefers not t.numel() to len(t) == 0 to export the model
273
- # fills previously skipped layers; checking for tensor causes errors
274
- self.key_cache[layer_idx] = key_states
275
- self.value_cache[layer_idx] = value_states
276
- else:
277
- self.key_cache[layer_idx] = torch.cat(
278
- [self.key_cache[layer_idx], key_states], dim=-2
279
- )
280
- self.value_cache[layer_idx] = torch.cat(
281
- [self.value_cache[layer_idx], value_states], dim=-2
282
- )
283
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
284
251
 
285
- def crop(self, max_length: int):
286
- """Crop the past key values up to a new `max_length`
287
- in terms of tokens. `max_length` can also be
288
- negative to remove `max_length` tokens.
289
- This is used in assisted decoding and contrastive search.
290
- """
291
- # In case it is negative
292
- if max_length < 0:
293
- max_length = self.get_seq_length() - abs(max_length)
294
-
295
- if self.get_seq_length() <= max_length:
296
- return
297
-
298
- if hasattr(self, "_seen_tokens"):
299
- self._seen_tokens = max_length
300
- for idx in range(len(self.key_cache)):
301
- if self.key_cache[idx].numel():
302
- self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
303
- self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
304
-
305
- @classmethod
306
- def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
307
- """This is the opposite of the above `batch_split()` method.
308
- This will be used by `stack_model_outputs` in
309
- `generation.utils`"""
310
- cache = cls()
311
- for idx in range(len(splits[0])):
312
- key_cache = [
313
- current.key_cache[idx] for current in splits if current.key_cache[idx].numel()
314
- ]
315
- value_cache = [
316
- current.value_cache[idx]
317
- for current in splits
318
- if current.value_cache[idx].numel()
319
- ]
320
- if key_cache != []:
321
- layer_keys = torch.cat(key_cache, dim=0)
322
- layer_values = torch.cat(value_cache, dim=0)
323
- cache.update(layer_keys, layer_values, idx)
324
- return cache
252
+ _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
253
+ _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
254
+
255
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
256
+ """Returns the sequence length of the cached states.
257
+ A layer index can be optionally passed."""
258
+ # TODO: deprecate this function in favor of `cache_position`
259
+ is_empty_layer = (
260
+ len(self.key_cache) == 0 # no cache in any layer
261
+ or len(self.key_cache)
262
+ <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
263
+ or self.key_cache[layer_idx].numel() == 0 # the layer has no cache
264
+ )
265
+ layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
266
+ return layer_seq_length
267
+
268
+ def reorder_cache(self, beam_idx: torch.LongTensor):
269
+ """Reorders the cache for beam search, given the selected beam indices."""
270
+ for layer_idx in range(len(self.key_cache)):
271
+ if self.key_cache[layer_idx].numel():
272
+ device = self.key_cache[layer_idx].device
273
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
274
+ 0, beam_idx.to(device)
275
+ )
276
+ if self.value_cache[layer_idx].numel():
277
+ device = self.value_cache[layer_idx].device
278
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
279
+ 0, beam_idx.to(device)
280
+ )
281
+
282
+ def update(
283
+ self,
284
+ key_states: torch.Tensor,
285
+ value_states: torch.Tensor,
286
+ layer_idx: int,
287
+ cache_kwargs: Optional[Dict[str, Any]] = None,
288
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
289
+ """
290
+ Updates the cache with the new `key_states`
291
+ and `value_states` for the layer `layer_idx`.
292
+ Parameters:
293
+ key_states (`torch.Tensor`):
294
+ The new key states to cache.
295
+ value_states (`torch.Tensor`):
296
+ The new value states to cache.
297
+ layer_idx (`int`):
298
+ The index of the layer to cache the states for.
299
+ cache_kwargs (`Dict[str, Any]`, `optional`):
300
+ Additional arguments for the cache subclass.
301
+ No additional arguments are used in `DynamicCache`.
302
+ Return:
303
+ A tuple containing the updated key and value states.
304
+ """
305
+ # Update the number of seen tokens
306
+ if layer_idx == 0:
307
+ if hasattr(self, "_seen_tokens"):
308
+ self._seen_tokens += key_states.shape[-2]
309
+
310
+ # Update the cache
311
+ if key_states is not None:
312
+ if len(self.key_cache) <= layer_idx:
313
+ # There may be skipped layers, fill them with empty lists
314
+ for _ in range(len(self.key_cache), layer_idx):
315
+ self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
316
+ self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
317
+ self.key_cache.append(key_states)
318
+ self.value_cache.append(value_states)
319
+ elif not self.key_cache[
320
+ layer_idx
321
+ ].numel(): # prefers not t.numel() to len(t) == 0 to export the model
322
+ # fills previously skipped layers; checking for tensor causes errors
323
+ self.key_cache[layer_idx] = key_states
324
+ self.value_cache[layer_idx] = value_states
325
+ else:
326
+ self.key_cache[layer_idx] = torch.cat(
327
+ [self.key_cache[layer_idx], key_states], dim=-2
328
+ )
329
+ self.value_cache[layer_idx] = torch.cat(
330
+ [self.value_cache[layer_idx], value_states], dim=-2
331
+ )
332
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
333
+
334
+ def crop(self, max_length: int):
335
+ """Crop the past key values up to a new `max_length`
336
+ in terms of tokens. `max_length` can also be
337
+ negative to remove `max_length` tokens.
338
+ This is used in assisted decoding and contrastive search.
339
+ """
340
+ # In case it is negative
341
+ if max_length < 0:
342
+ max_length = self.get_seq_length() - abs(max_length)
343
+
344
+ if self.get_seq_length() <= max_length:
345
+ return
346
+
347
+ if hasattr(self, "_seen_tokens"):
348
+ self._seen_tokens = max_length
349
+ for idx in range(len(self.key_cache)):
350
+ if self.key_cache[idx].numel():
351
+ self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
352
+ self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
353
+
354
+ @classmethod
355
+ def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
356
+ """This is the opposite of the above `batch_split()` method.
357
+ This will be used by `stack_model_outputs` in
358
+ `generation.utils`"""
359
+ cache = cls()
360
+ for idx in range(len(splits[0])):
361
+ key_cache = [
362
+ current.key_cache[idx]
363
+ for current in splits
364
+ if current.key_cache[idx].numel()
365
+ ]
366
+ value_cache = [
367
+ current.value_cache[idx]
368
+ for current in splits
369
+ if current.value_cache[idx].numel()
370
+ ]
371
+ if key_cache != []:
372
+ layer_keys = torch.cat(key_cache, dim=0)
373
+ layer_values = torch.cat(value_cache, dim=0)
374
+ cache.update(layer_keys, layer_values, idx)
375
+ return cache
325
376
 
326
377
 
327
378
  class patched_GenerationMixin:
@@ -1183,3 +1234,132 @@ class patched_IdeficsAttention(torch.nn.Module):
1183
1234
  if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
1184
1235
  return attn_output, attn_weights, past_key_value
1185
1236
  return attn_output, attn_weights
1237
+
1238
+
1239
+ class patched_SamMaskDecoder(torch.nn.Module):
1240
+ _PATCHES_ = ["forward"]
1241
+ _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
1242
+
1243
+ def forward(
1244
+ self,
1245
+ image_embeddings: torch.Tensor,
1246
+ image_positional_embeddings: torch.Tensor,
1247
+ sparse_prompt_embeddings: torch.Tensor,
1248
+ dense_prompt_embeddings: torch.Tensor,
1249
+ multimask_output: bool,
1250
+ output_attentions: Optional[bool] = None,
1251
+ attention_similarity: Optional[torch.Tensor] = None,
1252
+ target_embedding: Optional[torch.Tensor] = None,
1253
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1254
+ """
1255
+ Predict masks given image and prompt embeddings.
1256
+
1257
+ Args:
1258
+ image_embeddings (`torch.Tensor`):
1259
+ the embeddings from the image encoder
1260
+ image_positional_embeddings (`torch.Tensor`):
1261
+ positional encoding with the shape of image_embeddings
1262
+ sparse_prompt_embeddings (`torch.Tensor`):
1263
+ The embeddings of the points and boxes
1264
+ dense_prompt_embeddings (`torch.Tensor`):
1265
+ the embeddings of the mask inputs
1266
+ multimask_output (bool):
1267
+ Whether to return multiple masks or a single mask.
1268
+ output_attentions (bool, *optional*):
1269
+ Whether or not to return the attentions tensors of all attention layers.
1270
+ """
1271
+ batch_size, num_channels, height, width = image_embeddings.shape
1272
+ point_batch_size = sparse_prompt_embeddings.shape[1]
1273
+ # Concatenate output tokens
1274
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
1275
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
1276
+
1277
+ # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
1278
+ # torch.any is needed to avoid data-dependent control flow
1279
+ # with sparse_prompt_embeddings.sum().item() != 0
1280
+ def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
1281
+ return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
1282
+
1283
+ def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
1284
+ return output_tokens.clone()
1285
+
1286
+ tokens = torch.cond(
1287
+ torch.any(sparse_prompt_embeddings != 0),
1288
+ sparse_prompt_embeddings_is_not_empty,
1289
+ sparse_prompt_embeddings_is_empty,
1290
+ [output_tokens, sparse_prompt_embeddings],
1291
+ )
1292
+
1293
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
1294
+
1295
+ # Expand per-image data in batch direction to be per-point
1296
+ image_embeddings = image_embeddings + dense_prompt_embeddings
1297
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
1298
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(
1299
+ point_batch_size, 0
1300
+ )
1301
+
1302
+ # Run the transformer, image_positional_embeddings are consumed
1303
+ torch._check(point_embeddings.shape[0] != 0)
1304
+ torch._check(point_embeddings.shape[1] != 0)
1305
+ torch._check(point_embeddings.shape[2] != 0)
1306
+ torch._check(point_embeddings.shape[3] != 0)
1307
+ embeddings_attentions = self.transformer(
1308
+ point_embeddings=point_embeddings,
1309
+ image_embeddings=image_embeddings,
1310
+ image_positional_embeddings=image_positional_embeddings,
1311
+ attention_similarity=attention_similarity,
1312
+ target_embedding=target_embedding,
1313
+ output_attentions=output_attentions,
1314
+ )
1315
+ point_embedding, image_embeddings = embeddings_attentions[:2]
1316
+ iou_token_out = torch.select(point_embedding, dim=2, index=0)
1317
+ mask_tokens_out = torch.narrow(
1318
+ point_embedding, dim=2, start=1, length=self.num_mask_tokens
1319
+ )
1320
+
1321
+ # Upscale mask embeddings and predict masks using the mask tokens
1322
+ image_embeddings = image_embeddings.transpose(2, 3).reshape(
1323
+ batch_size * point_batch_size, num_channels, height, width
1324
+ )
1325
+
1326
+ upscaled_embedding = self.upscale_conv1(image_embeddings)
1327
+ upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
1328
+ upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
1329
+
1330
+ hyper_in_list = []
1331
+ for i in range(self.num_mask_tokens):
1332
+ current_mlp = self.output_hypernetworks_mlps[i]
1333
+ hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
1334
+ hyper_in = torch.stack(hyper_in_list, dim=2)
1335
+
1336
+ _, num_channels, height, width = upscaled_embedding.shape
1337
+ upscaled_embedding = upscaled_embedding.reshape(
1338
+ batch_size, point_batch_size, num_channels, height * width
1339
+ )
1340
+ masks = (hyper_in @ upscaled_embedding).reshape(
1341
+ batch_size, point_batch_size, -1, height, width
1342
+ )
1343
+
1344
+ # Generate mask quality predictions
1345
+ iou_pred = self.iou_prediction_head(iou_token_out)
1346
+
1347
+ # Select the correct mask or masks for output
1348
+ if multimask_output:
1349
+ mask_slice = slice(1, None)
1350
+ else:
1351
+ mask_slice = slice(0, 1)
1352
+ masks = masks[:, :, mask_slice, :, :]
1353
+ iou_pred = iou_pred[:, :, mask_slice]
1354
+
1355
+ outputs = (masks, iou_pred)
1356
+
1357
+ if len(embeddings_attentions) == 2:
1358
+ # transformers==4.54
1359
+ return outputs
1360
+
1361
+ if output_attentions and len(embeddings_attentions) > 2:
1362
+ outputs = outputs + (embeddings_attentions[2],) # noqa: RUF005
1363
+ else:
1364
+ outputs = outputs + (None,) # noqa: RUF005
1365
+ return outputs