onnx-diagnostic 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +66 -8
  3. onnx_diagnostic/ext_test_case.py +2 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +250 -15
  6. onnx_diagnostic/helpers/helper.py +146 -10
  7. onnx_diagnostic/helpers/log_helper.py +404 -315
  8. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  9. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  10. onnx_diagnostic/helpers/torch_helper.py +33 -11
  11. onnx_diagnostic/tasks/__init__.py +2 -0
  12. onnx_diagnostic/tasks/feature_extraction.py +86 -5
  13. onnx_diagnostic/tasks/image_text_to_text.py +260 -56
  14. onnx_diagnostic/tasks/mask_generation.py +139 -0
  15. onnx_diagnostic/tasks/text2text_generation.py +2 -2
  16. onnx_diagnostic/tasks/text_generation.py +6 -2
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +7 -1
  18. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  19. onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +397 -128
  21. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +5 -0
  24. onnx_diagnostic/torch_models/validate.py +26 -3
  25. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +29 -27
  27. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.4.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,19 @@
  import inspect
  from dataclasses import dataclass
  from functools import wraps
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Callable, List, Optional, Tuple
  import packaging.version as pv
  import torch
  import transformers
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
- from transformers.cache_utils import StaticCache, Cache, DynamicCache
+ from transformers.cache_utils import StaticCache, Cache
+
+ try:
+     from transformers.cache_utils import parse_processor_args  # noqa: F401
+
+     patch_parse_processor_args = True
+ except ImportError:
+     patch_parse_processor_args = False
  
  try:
      import transformers.masking_utils
@@ -15,10 +22,10 @@ try:
  except ImportError:
      patch_masking_utils = False
  
+
  from ...ext_test_case import has_transformers
  from ...helpers.torch_helper import is_torchdynamo_exporting
  
-
  if patch_masking_utils:
      # Introduced in 4.52
      from transformers.masking_utils import causal_mask_function, sdpa_mask
@@ -110,6 +117,46 @@ if patch_masking_utils:
          return mask
  
  
+ if patch_parse_processor_args:
+
+     def _init_cache_inspect():
+         res = {}
+         for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
+             try:
+                 params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+                 res[processor_class.__init__] = params
+             except Exception:
+                 res[processor_class.__init__] = None
+         return res
+
+     _cache_inspect = _init_cache_inspect()
+
+     def patched_parse_processor_args(
+         processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
+     ) -> tuple[dict, dict]:
+         """[patch:transformers.cache_utils.parse_processor_args]"""
+         # If not patched, this fails with transformers>=4.54 because
+         # ``parse_processor_args`` relies on inspect and the exporter is
+         # not very fond of that:
+         #     torch._dynamo.exc.Unsupported: id() with unsupported args
+         #     Explanation: Dynamo doesn't know how to trace id() call with args
+         #     (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+         #     Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+         #     objects from outside the compiled region.
+         #     Hint: It may be possible to write Dynamo tracing rules for this code.
+         #
+         # The patch caches the signature to avoid any call to inspect.
+         if processor_class is None:
+             return {}, kwargs
+         params = _cache_inspect[processor_class.__init__]
+         if params is None:
+             return {}, kwargs
+         processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+         remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+         return processor_kwargs, remaining_kwargs
+
+
  def _patch_make_causal_mask(
      input_ids_shape: torch.Size,
      dtype: torch.dtype,
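
As a minimal sketch of the technique used by patched_parse_processor_args (run inspect.signature once at import time, then split keyword arguments with plain dictionary lookups so nothing introspective executes inside the traced region); DummyProcessor and split_kwargs below are invented names for illustration, not transformers API:

    import inspect

    class DummyProcessor:
        # stand-in for a cache processor class; not part of transformers
        def __init__(self, cache, window_size=8, dtype=None):
            self.window_size = window_size
            self.dtype = dtype

    # computed once, at import time: parameter names after (self, cache)
    _PARAMS = list(inspect.signature(DummyProcessor.__init__).parameters)[2:]

    def split_kwargs(kwargs: dict) -> tuple:
        # only dictionary lookups here, which Dynamo traces without trouble
        taken = {k: kwargs[k] for k in _PARAMS if k in kwargs}
        rest = {k: v for k, v in kwargs.items() if k not in taken}
        return taken, rest

    print(split_kwargs({"window_size": 4, "attention_mask": None}))
    # ({'window_size': 4}, {'attention_mask': None})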
@@ -192,134 +239,140 @@ class patched_AttentionMaskConverter:
          return _patch_make_causal_mask(**kwargs)
  
  
- class patched_DynamicCache:
-     """
-     Applies modifications implemented in PR
-     `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
-     """
-
-     _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
-     _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
-
-     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-         """Returns the sequence length of the cached states.
-         A layer index can be optionally passed."""
-         # TODO: deprecate this function in favor of `cache_position`
-         is_empty_layer = (
-             len(self.key_cache) == 0  # no cache in any layer
-             or len(self.key_cache)
-             <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-             or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
-         )
-         layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
-         return layer_seq_length
-
-     def reorder_cache(self, beam_idx: torch.LongTensor):
-         """Reorders the cache for beam search, given the selected beam indices."""
-         for layer_idx in range(len(self.key_cache)):
-             if self.key_cache[layer_idx].numel():
-                 device = self.key_cache[layer_idx].device
-                 self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
-                     0, beam_idx.to(device)
-                 )
-             if self.value_cache[layer_idx].numel():
-                 device = self.value_cache[layer_idx].device
-                 self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
-                     0, beam_idx.to(device)
-                 )
+ if pv.Version(transformers.__version__) < pv.Version("4.51"):
+     from typing import Any, Dict
+     from transformers.cache_utils import DynamicCache
  
-     def update(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         layer_idx: int,
-         cache_kwargs: Optional[Dict[str, Any]] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
+     class patched_DynamicCache:
          """
-         Updates the cache with the new `key_states`
-         and `value_states` for the layer `layer_idx`.
-
-         Parameters:
-             key_states (`torch.Tensor`):
-                 The new key states to cache.
-             value_states (`torch.Tensor`):
-                 The new value states to cache.
-             layer_idx (`int`):
-                 The index of the layer to cache the states for.
-             cache_kwargs (`Dict[str, Any]`, `optional`):
-                 Additional arguments for the cache subclass.
-                 No additional arguments are used in `DynamicCache`.
-
-         Return:
-             A tuple containing the updated key and value states.
+         Applies modifications implemented in PR
+         `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
          """
-         # Update the number of seen tokens
-         if layer_idx == 0:
-             self._seen_tokens += key_states.shape[-2]
-
-         # Update the cache
-         if key_states is not None:
-             if len(self.key_cache) <= layer_idx:
-                 # There may be skipped layers, fill them with empty lists
-                 for _ in range(len(self.key_cache), layer_idx):
-                     self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
-                     self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
-                 self.key_cache.append(key_states)
-                 self.value_cache.append(value_states)
-             elif not self.key_cache[
-                 layer_idx
-             ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
-                 # fills previously skipped layers; checking for tensor causes errors
-                 self.key_cache[layer_idx] = key_states
-                 self.value_cache[layer_idx] = value_states
-             else:
-                 self.key_cache[layer_idx] = torch.cat(
-                     [self.key_cache[layer_idx], key_states], dim=-2
-                 )
-                 self.value_cache[layer_idx] = torch.cat(
-                     [self.value_cache[layer_idx], value_states], dim=-2
-                 )
-         return self.key_cache[layer_idx], self.value_cache[layer_idx]
  
-     def crop(self, max_length: int):
-         """Crop the past key values up to a new `max_length`
-         in terms of tokens. `max_length` can also be
-         negative to remove `max_length` tokens.
-         This is used in assisted decoding and contrastive search.
-         """
-         # In case it is negative
-         if max_length < 0:
-             max_length = self.get_seq_length() - abs(max_length)
-
-         if self.get_seq_length() <= max_length:
-             return
-
-         self._seen_tokens = max_length
-         for idx in range(len(self.key_cache)):
-             if self.key_cache[idx].numel():
-                 self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
-                 self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
-
-     @classmethod
-     def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
-         """This is the opposite of the above `batch_split()` method.
-         This will be used by `stack_model_outputs` in
-         `generation.utils`"""
-         cache = cls()
-         for idx in range(len(splits[0])):
-             key_cache = [
-                 current.key_cache[idx] for current in splits if current.key_cache[idx].numel()
-             ]
-             value_cache = [
-                 current.value_cache[idx]
-                 for current in splits
-                 if current.value_cache[idx].numel()
-             ]
-             if key_cache != []:
-                 layer_keys = torch.cat(key_cache, dim=0)
-                 layer_values = torch.cat(value_cache, dim=0)
-                 cache.update(layer_keys, layer_values, idx)
-         return cache
+         _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+         _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+         def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+             """Returns the sequence length of the cached states.
+             A layer index can be optionally passed."""
+             # TODO: deprecate this function in favor of `cache_position`
+             is_empty_layer = (
+                 len(self.key_cache) == 0  # no cache in any layer
+                 or len(self.key_cache)
+                 <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                 or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+             )
+             layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+             return layer_seq_length
+
+         def reorder_cache(self, beam_idx: torch.LongTensor):
+             """Reorders the cache for beam search, given the selected beam indices."""
+             for layer_idx in range(len(self.key_cache)):
+                 if self.key_cache[layer_idx].numel():
+                     device = self.key_cache[layer_idx].device
+                     self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                         0, beam_idx.to(device)
+                     )
+                 if self.value_cache[layer_idx].numel():
+                     device = self.value_cache[layer_idx].device
+                     self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                         0, beam_idx.to(device)
+                     )
+
+         def update(
+             self,
+             key_states: torch.Tensor,
+             value_states: torch.Tensor,
+             layer_idx: int,
+             cache_kwargs: Optional[Dict[str, Any]] = None,
+         ) -> Tuple[torch.Tensor, torch.Tensor]:
+             """
+             Updates the cache with the new `key_states`
+             and `value_states` for the layer `layer_idx`.
+             Parameters:
+                 key_states (`torch.Tensor`):
+                     The new key states to cache.
+                 value_states (`torch.Tensor`):
+                     The new value states to cache.
+                 layer_idx (`int`):
+                     The index of the layer to cache the states for.
+                 cache_kwargs (`Dict[str, Any]`, `optional`):
+                     Additional arguments for the cache subclass.
+                     No additional arguments are used in `DynamicCache`.
+             Return:
+                 A tuple containing the updated key and value states.
+             """
+             # Update the number of seen tokens
+             if layer_idx == 0:
+                 if hasattr(self, "_seen_tokens"):
+                     self._seen_tokens += key_states.shape[-2]
+
+             # Update the cache
+             if key_states is not None:
+                 if len(self.key_cache) <= layer_idx:
+                     # There may be skipped layers, fill them with empty lists
+                     for _ in range(len(self.key_cache), layer_idx):
+                         self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                         self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                     self.key_cache.append(key_states)
+                     self.value_cache.append(value_states)
+                 elif not self.key_cache[
+                     layer_idx
+                 ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
+                     # fills previously skipped layers; checking for tensor causes errors
+                     self.key_cache[layer_idx] = key_states
+                     self.value_cache[layer_idx] = value_states
+                 else:
+                     self.key_cache[layer_idx] = torch.cat(
+                         [self.key_cache[layer_idx], key_states], dim=-2
+                     )
+                     self.value_cache[layer_idx] = torch.cat(
+                         [self.value_cache[layer_idx], value_states], dim=-2
+                     )
+             return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+         def crop(self, max_length: int):
+             """Crop the past key values up to a new `max_length`
+             in terms of tokens. `max_length` can also be
+             negative to remove `max_length` tokens.
+             This is used in assisted decoding and contrastive search.
+             """
+             # In case it is negative
+             if max_length < 0:
+                 max_length = self.get_seq_length() - abs(max_length)
+
+             if self.get_seq_length() <= max_length:
+                 return
+
+             if hasattr(self, "_seen_tokens"):
+                 self._seen_tokens = max_length
+             for idx in range(len(self.key_cache)):
+                 if self.key_cache[idx].numel():
+                     self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                     self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+         @classmethod
+         def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+             """This is the opposite of the above `batch_split()` method.
+             This will be used by `stack_model_outputs` in
+             `generation.utils`"""
+             cache = cls()
+             for idx in range(len(splits[0])):
+                 key_cache = [
+                     current.key_cache[idx]
+                     for current in splits
+                     if current.key_cache[idx].numel()
+                 ]
+                 value_cache = [
+                     current.value_cache[idx]
+                     for current in splits
+                     if current.value_cache[idx].numel()
+                 ]
+                 if key_cache != []:
+                     layer_keys = torch.cat(key_cache, dim=0)
+                     layer_values = torch.cat(value_cache, dim=0)
+                     cache.update(layer_keys, layer_values, idx)
+             return cache
  
  
  class patched_GenerationMixin:
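
As a toy sketch of what the patched update and crop do: update concatenates the incoming key/value states with the cached ones along the sequence axis (dim -2), and crop slices along the same axis. The shapes below are made up and independent of DynamicCache:

    import torch

    # cached keys: (batch, num_heads, past_seq_len, head_dim)
    key_cache = torch.zeros(1, 2, 3, 4)
    new_keys = torch.ones(1, 2, 1, 4)  # one freshly decoded token

    key_cache = torch.cat([key_cache, new_keys], dim=-2)
    print(key_cache.shape)  # torch.Size([1, 2, 4, 4]): the sequence grew from 3 to 4

    # crop(max_length) keeps only the first max_length tokens on that axis
    print(key_cache[..., :2, :].shape)  # torch.Size([1, 2, 2, 4])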
@@ -862,6 +915,91 @@ def patched_dynamic_rope_update(rope_forward):
      return wrapper
  
  
+ def common_eager_attention_forward(
+     module: torch.nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: Optional[float] = None,
+     dropout: float = 0.0,
+     head_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ):
+     if scaling is None:
+         scaling = query.size(-1) ** -0.5
+
+     attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         # The two following lines were added.
+         if attention_mask is not None and attention_mask.ndim == 4:
+             attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+         attn_weights = attn_weights + attention_mask
+
+     attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+
+     if head_mask is not None:
+         attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+     attn_weights = torch.nn.functional.dropout(
+         attn_weights, p=dropout, training=module.training
+     )
+     attn_output = torch.matmul(attn_weights, value)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ def patched_model_bart_eager_attention_forward(
+     module: torch.nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: Optional[float] = None,
+     dropout: float = 0.0,
+     head_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ):
+     """[patch:transformers.models.bart.modeling_bart.eager_attention_forward]"""
+     return common_eager_attention_forward(
+         module,
+         query,
+         key,
+         value,
+         attention_mask=attention_mask,
+         scaling=scaling,
+         dropout=dropout,
+         head_mask=head_mask,
+         **kwargs,
+     )
+
+
+ def patched_modeling_marian_eager_attention_forward(
+     module: torch.nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: Optional[float] = None,
+     dropout: float = 0.0,
+     head_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ):
+     """[patch:transformers.models.marian.modeling_marian.eager_attention_forward]"""
+     return common_eager_attention_forward(
+         module,
+         query,
+         key,
+         value,
+         attention_mask=attention_mask,
+         scaling=scaling,
+         dropout=dropout,
+         head_mask=head_mask,
+         **kwargs,
+     )
+
+
  class common_RotaryEmbedding(torch.nn.Module):
      @torch.no_grad()
      @patched_dynamic_rope_update
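
The two lines added in common_eager_attention_forward trim a 4D mask to the current key length so the addition to the attention scores broadcasts against cached keys. A toy sketch of the shapes involved (all values and sizes below are arbitrary, chosen only for illustration):

    import torch

    batch, heads, q_len, kv_len, head_dim = 1, 2, 1, 5, 8
    query = torch.randn(batch, heads, q_len, head_dim)
    key = torch.randn(batch, heads, kv_len, head_dim)

    # a mask padded to a longer maximum length than the cached keys
    attention_mask = torch.zeros(batch, 1, q_len, 7)

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * query.size(-1) ** -0.5
    if attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]  # (1, 1, 1, 5)
    attn_weights = attn_weights + attention_mask  # broadcasts to (1, 2, 1, 5)
    print(attn_weights.shape)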
@@ -1093,4 +1231,135 @@ class patched_IdeficsAttention(torch.nn.Module):
          if output_attentions:
              attn_weights = None
  
-         return attn_output, attn_weights, past_key_value
+         if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
+             return attn_output, attn_weights, past_key_value
+         return attn_output, attn_weights
+
+
+ class patched_SamMaskDecoder(torch.nn.Module):
+     _PATCHES_ = ["forward"]
+     _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
+
+     def forward(
+         self,
+         image_embeddings: torch.Tensor,
+         image_positional_embeddings: torch.Tensor,
+         sparse_prompt_embeddings: torch.Tensor,
+         dense_prompt_embeddings: torch.Tensor,
+         multimask_output: bool,
+         output_attentions: Optional[bool] = None,
+         attention_similarity: Optional[torch.Tensor] = None,
+         target_embedding: Optional[torch.Tensor] = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Predict masks given image and prompt embeddings.
+
+         Args:
+             image_embeddings (`torch.Tensor`):
+                 the embeddings from the image encoder
+             image_positional_embedding (`torch.Tensor`):
+                 positional encoding with the shape of image_embeddings
+             sparse_prompt_embeddings (`torch.Tensor`):
+                 The embeddings of the points and boxes
+             dense_prompt_embeddings (`torch.Tensor`):
+                 the embeddings of the mask inputs
+             multimask_output (bool):
+                 Whether to return multiple masks or a single mask.
+             output_attentions (bool, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers.
+         """
+         batch_size, num_channels, height, width = image_embeddings.shape
+         point_batch_size = sparse_prompt_embeddings.shape[1]
+         # Concatenate output tokens
+         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+         output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+         # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
+         # torch.any is needed to avoid data-dependent control flow
+         # with sparse_prompt_embeddings.sum().item() != 0
+         def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
+             return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+
+         def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
+             return output_tokens.clone()
+
+         tokens = torch.cond(
+             torch.any(sparse_prompt_embeddings != 0),
+             sparse_prompt_embeddings_is_not_empty,
+             sparse_prompt_embeddings_is_empty,
+             [output_tokens, sparse_prompt_embeddings],
+         )
+
+         point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+         # Expand per-image data in batch direction to be per-point
+         image_embeddings = image_embeddings + dense_prompt_embeddings
+         image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+         image_positional_embeddings = image_positional_embeddings.repeat_interleave(
+             point_batch_size, 0
+         )
+
+         # Run the transformer, image_positional_embedding are consumed
+         torch._check(point_embeddings.shape[0] != 0)
+         torch._check(point_embeddings.shape[1] != 0)
+         torch._check(point_embeddings.shape[2] != 0)
+         torch._check(point_embeddings.shape[3] != 0)
+         embeddings_attentions = self.transformer(
+             point_embeddings=point_embeddings,
+             image_embeddings=image_embeddings,
+             image_positional_embeddings=image_positional_embeddings,
+             attention_similarity=attention_similarity,
+             target_embedding=target_embedding,
+             output_attentions=output_attentions,
+         )
+         point_embedding, image_embeddings = embeddings_attentions[:2]
+         iou_token_out = torch.select(point_embedding, dim=2, index=0)
+         mask_tokens_out = torch.narrow(
+             point_embedding, dim=2, start=1, length=self.num_mask_tokens
+         )
+
+         # Upscale mask embeddings and predict masks using the mask tokens
+         image_embeddings = image_embeddings.transpose(2, 3).reshape(
+             batch_size * point_batch_size, num_channels, height, width
+         )
+
+         upscaled_embedding = self.upscale_conv1(image_embeddings)
+         upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+         upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
+         hyper_in_list = []
+         for i in range(self.num_mask_tokens):
+             current_mlp = self.output_hypernetworks_mlps[i]
+             hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+         hyper_in = torch.stack(hyper_in_list, dim=2)
+
+         _, num_channels, height, width = upscaled_embedding.shape
+         upscaled_embedding = upscaled_embedding.reshape(
+             batch_size, point_batch_size, num_channels, height * width
+         )
+         masks = (hyper_in @ upscaled_embedding).reshape(
+             batch_size, point_batch_size, -1, height, width
+         )
+
+         # Generate mask quality predictions
+         iou_pred = self.iou_prediction_head(iou_token_out)
+
+         # Select the correct mask or masks for output
+         if multimask_output:
+             mask_slice = slice(1, None)
+         else:
+             mask_slice = slice(0, 1)
+         masks = masks[:, :, mask_slice, :, :]
+         iou_pred = iou_pred[:, :, mask_slice]
+
+         outputs = (masks, iou_pred)
+
+         if len(embeddings_attentions) == 2:
+             # transformers==4.54
+             return outputs
+
+         if output_attentions and len(embeddings_attentions) > 2:
+             outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
+         else:
+             outputs = outputs + (None,)  # noqa: RUF005
+         return outputs
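
patched_SamMaskDecoder relies on torch.cond so the branch on sparse_prompt_embeddings stays a tensor predicate instead of a data-dependent .item() check. A minimal standalone sketch of that pattern under a recent PyTorch; CondExample and its shapes are invented for illustration and are not part of this package:

    import torch

    class CondExample(torch.nn.Module):
        def forward(self, tokens, sparse):
            def use_sparse(tokens, sparse):
                return tokens + sparse

            def ignore_sparse(tokens, sparse):
                return tokens.clone()

            # the predicate is a tensor, so export never needs .item()
            return torch.cond(
                torch.any(sparse != 0), use_sparse, ignore_sparse, [tokens, sparse]
            )

    model = CondExample()
    print(model(torch.ones(2, 3), torch.zeros(2, 3)))  # eager: the "empty" branch runs
    ep = torch.export.export(model, (torch.ones(2, 3), torch.zeros(2, 3)))
    print(type(ep))  # both branches are captured in the exported program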