onnx-diagnostic 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +56 -3
  3. onnx_diagnostic/export/dynamic_shapes.py +24 -10
  4. onnx_diagnostic/export/shape_helper.py +6 -2
  5. onnx_diagnostic/ext_test_case.py +2 -0
  6. onnx_diagnostic/helpers/_log_helper.py +6 -6
  7. onnx_diagnostic/helpers/cache_helper.py +326 -18
  8. onnx_diagnostic/helpers/config_helper.py +10 -0
  9. onnx_diagnostic/helpers/helper.py +152 -11
  10. onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
  11. onnx_diagnostic/helpers/onnx_helper.py +13 -7
  12. onnx_diagnostic/helpers/torch_helper.py +33 -11
  13. onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
  14. onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
  15. onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
  16. onnx_diagnostic/tasks/__init__.py +2 -0
  17. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  18. onnx_diagnostic/tasks/feature_extraction.py +7 -3
  19. onnx_diagnostic/tasks/fill_mask.py +6 -2
  20. onnx_diagnostic/tasks/image_classification.py +6 -2
  21. onnx_diagnostic/tasks/image_text_to_text.py +289 -62
  22. onnx_diagnostic/tasks/mask_generation.py +143 -0
  23. onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
  24. onnx_diagnostic/tasks/object_detection.py +6 -2
  25. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  26. onnx_diagnostic/tasks/summarization.py +7 -2
  27. onnx_diagnostic/tasks/text2text_generation.py +7 -2
  28. onnx_diagnostic/tasks/text_classification.py +6 -2
  29. onnx_diagnostic/tasks/text_generation.py +14 -16
  30. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
  31. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
  32. onnx_diagnostic/torch_export_patches/patch_inputs.py +5 -2
  33. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
  34. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +428 -129
  35. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +60 -41
  36. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
  37. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
  38. onnx_diagnostic/torch_models/validate.py +1 -0
  39. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
  40. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +43 -42
  41. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
  42. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
  43. {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,12 +1,20 @@
 import inspect
+import math
 from dataclasses import dataclass
 from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.cache_utils import StaticCache, Cache, DynamicCache
+from transformers.cache_utils import StaticCache, Cache
+
+try:
+    from transformers.cache_utils import parse_processor_args  # noqa: F401
+
+    patch_parse_processor_args = True
+except ImportError:
+    patch_parse_processor_args = False
 
 try:
     import transformers.masking_utils
@@ -15,10 +23,18 @@ try:
 except ImportError:
     patch_masking_utils = False
 
+
+try:
+    # transformers>= 4.55.1
+    from transformers.cache_utils import DynamicLayer
+
+    patch_DynamicLayer = hasattr(DynamicLayer, "lazy_initialization")
+except ImportError:
+    patch_DynamicLayer = False
+
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting
 
-
 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import causal_mask_function, sdpa_mask
@@ -110,6 +126,60 @@ if patch_masking_utils:
         return mask
 
 
+if patch_parse_processor_args:
+
+    def _init_cache_inspect():
+        res = {}
+        for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
+            try:
+                params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+                res[processor_class.__init__] = params
+            except Exception:
+                res[processor_class.__init__] = None
+        return res
+
+    _cache_inspect = _init_cache_inspect()
+
+    def patched_parse_processor_args(
+        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
+    ) -> tuple[dict, dict]:
+        """[patch:transformers.cache_utils.parse_processor_args]"""
+        # If not patched...
+        # Fails with transformers>=4.54 because function ``parse_processor_args``
+        # relies in inspect and the exporter is not very fond of that.
+        # torch._dynamo.exc.Unsupported: id() with unsupported args
+        # Explanation: Dynamo doesn't know how to trace id()
+        # call with args
+        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+        # objects from outside the compiled region.
+        # Hint: It may be possible to write Dynamo tracing rules for this code.
+        #
+        # The patch is caching the signature to avoid any call to inspect.
+        if processor_class is None:
+            return {}, kwargs
+        params = _cache_inspect[processor_class.__init__]
+        if params is None:
+            return {}, kwargs
+        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+        return processor_kwargs, remaining_kwargs
+
+
+if patch_DynamicLayer:
+
+    class patched_DynamicLayer:
+        _PATCHES_ = ["lazy_initialization"]
+        _PATCHED_CLASS_ = DynamicLayer
+
+        def lazy_initialization(self, key_states: torch.Tensor):
+            self.dtype, self.device = key_states.dtype, key_states.device
+            new_shape = list(key_states.shape)
+            new_shape[-2] = 0
+            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+
+
 def _patch_make_causal_mask(
     input_ids_shape: torch.Size,
     dtype: torch.dtype,
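The `patched_parse_processor_args` patch above works by moving every `inspect.signature` call to import time and keeping only dictionary lookups inside the traced code, which torch.export/Dynamo can handle. A minimal sketch of that idea, using a hypothetical `_DemoProcessor` class and helper names (`_PARAMS`, `split_kwargs`) that are not part of transformers or onnx-diagnostic:

```python
import inspect
from typing import Optional


class _DemoProcessor:
    """Hypothetical stand-in for a cache processor class."""

    def __init__(self, cache, window_size: int = 8, dtype=None):
        self.window_size, self.dtype = window_size, dtype


# Signature inspection happens once, at import time, outside any traced region.
_PARAMS = {
    _DemoProcessor.__init__: list(inspect.signature(_DemoProcessor.__init__).parameters)[2:]
}


def split_kwargs(processor_class: Optional[type], kwargs: dict) -> tuple[dict, dict]:
    # Only dict lookups remain here, so Dynamo never has to trace inspect or id().
    if processor_class is None:
        return {}, kwargs
    params = _PARAMS.get(processor_class.__init__)
    if params is None:
        return {}, kwargs
    taken = {k: kwargs[k] for k in params if k in kwargs}
    return taken, {k: v for k, v in kwargs.items() if k not in taken}


print(split_kwargs(_DemoProcessor, {"window_size": 4, "batch_size": 2}))
# ({'window_size': 4}, {'batch_size': 2})
```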
@@ -192,136 +262,148 @@ class patched_AttentionMaskConverter:
         return _patch_make_causal_mask(**kwargs)
 
 
-class patched_DynamicCache:
-    """
-    Applies modifications implemented in PR
-    `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
-    """
-
-    _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
-    _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states.
-        A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        is_empty_layer = (
-            len(self.key_cache) == 0  # no cache in any layer
-            or len(self.key_cache)
-            <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-            or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
-        )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
-        return layer_seq_length
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx].numel():
-                device = self.key_cache[layer_idx].device
-                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
-                    0, beam_idx.to(device)
-                )
-            if self.value_cache[layer_idx].numel():
-                device = self.value_cache[layer_idx].device
-                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
-                    0, beam_idx.to(device)
-                )
+if pv.Version(transformers.__version__) < pv.Version("4.51"):
+    from typing import Any, Dict
+    from transformers.cache_utils import DynamicCache
 
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    class patched_DynamicCache:
         """
-        Updates the cache with the new `key_states`
-        and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass.
-                No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
+        Applies modifications implemented in PR
+        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
         """
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            if hasattr(self, "_seen_tokens"):
-                self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
-                    self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif not self.key_cache[
-                layer_idx
-            ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
-                # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat(
-                    [self.key_cache[layer_idx], key_states], dim=-2
-                )
-                self.value_cache[layer_idx] = torch.cat(
-                    [self.value_cache[layer_idx], value_states], dim=-2
-                )
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-    def crop(self, max_length: int):
-        """Crop the past key values up to a new `max_length`
-        in terms of tokens. `max_length` can also be
-        negative to remove `max_length` tokens.
-        This is used in assisted decoding and contrastive search.
-        """
-        # In case it is negative
-        if max_length < 0:
-            max_length = self.get_seq_length() - abs(max_length)
-
-        if self.get_seq_length() <= max_length:
-            return
-
-        if hasattr(self, "_seen_tokens"):
-            self._seen_tokens = max_length
-        for idx in range(len(self.key_cache)):
-            if self.key_cache[idx].numel():
-                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
-                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
-
-    @classmethod
-    def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
-        """This is the opposite of the above `batch_split()` method.
-        This will be used by `stack_model_outputs` in
-        `generation.utils`"""
-        cache = cls()
-        for idx in range(len(splits[0])):
-            key_cache = [
-                current.key_cache[idx] for current in splits if current.key_cache[idx].numel()
-            ]
-            value_cache = [
-                current.value_cache[idx]
-                for current in splits
-                if current.value_cache[idx].numel()
-            ]
-            if key_cache != []:
-                layer_keys = torch.cat(key_cache, dim=0)
-                layer_values = torch.cat(value_cache, dim=0)
-                cache.update(layer_keys, layer_values, idx)
-        return cache
+        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+            """Returns the sequence length of the cached states.
+            A layer index can be optionally passed."""
+            # TODO: deprecate this function in favor of `cache_position`
+            is_empty_layer = (
+                len(self.key_cache) == 0  # no cache in any layer
+                or len(self.key_cache)
+                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+            )
+            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+            return layer_seq_length
+
+        def reorder_cache(self, beam_idx: torch.LongTensor):
+            """Reorders the cache for beam search, given the selected beam indices."""
+            for layer_idx in range(len(self.key_cache)):
+                if self.key_cache[layer_idx].numel():
+                    device = self.key_cache[layer_idx].device
+                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+                if self.value_cache[layer_idx].numel():
+                    device = self.value_cache[layer_idx].device
+                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+
+        def update(
+            self,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            layer_idx: int,
+            cache_kwargs: Optional[Dict[str, Any]] = None,
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Updates the cache with the new `key_states`
+            and `value_states` for the layer `layer_idx`.
+            Parameters:
+                key_states (`torch.Tensor`):
+                    The new key states to cache.
+                value_states (`torch.Tensor`):
+                    The new value states to cache.
+                layer_idx (`int`):
+                    The index of the layer to cache the states for.
+                cache_kwargs (`Dict[str, Any]`, `optional`):
+                    Additional arguments for the cache subclass.
+                    No additional arguments are used in `DynamicCache`.
+            Return:
+                A tuple containing the updated key and value states.
+            """
+            # Update the number of seen tokens
+            if layer_idx == 0:
+                if hasattr(self, "_seen_tokens"):
+                    self._seen_tokens += key_states.shape[-2]
+
+            # Update the cache
+            if key_states is not None:
+                if len(self.key_cache) <= layer_idx:
+                    # There may be skipped layers, fill them with empty lists
+                    for _ in range(len(self.key_cache), layer_idx):
+                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                    self.key_cache.append(key_states)
+                    self.value_cache.append(value_states)
+                elif not self.key_cache[
+                    layer_idx
+                ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
+                    # fills previously skipped layers; checking for tensor causes errors
+                    self.key_cache[layer_idx] = key_states
+                    self.value_cache[layer_idx] = value_states
+                else:
+                    torch._check(
+                        len(self.key_cache[layer_idx].shape) == len(key_states.shape),
+                        lambda: (
+                            f"Rank mismatch len(self.key_cache[layer_idx].shape)="
+                            f"{len(self.key_cache[layer_idx].shape)}, "
+                            f"len(key_states.shape)={len(key_states.shape)}"
+                        ),
+                    )
+                    self.key_cache[layer_idx] = torch.cat(
+                        [self.key_cache[layer_idx], key_states], dim=-2
+                    )
+                    self.value_cache[layer_idx] = torch.cat(
+                        [self.value_cache[layer_idx], value_states], dim=-2
+                    )
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        def crop(self, max_length: int):
+            """Crop the past key values up to a new `max_length`
+            in terms of tokens. `max_length` can also be
+            negative to remove `max_length` tokens.
+            This is used in assisted decoding and contrastive search.
+            """
+            # In case it is negative
+            if max_length < 0:
+                max_length = self.get_seq_length() - abs(max_length)
+
+            if self.get_seq_length() <= max_length:
+                return
+
+            if hasattr(self, "_seen_tokens"):
+                self._seen_tokens = max_length
+            for idx in range(len(self.key_cache)):
+                if self.key_cache[idx].numel():
+                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+        @classmethod
+        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+            """This is the opposite of the above `batch_split()` method.
+            This will be used by `stack_model_outputs` in
+            `generation.utils`"""
+            cache = cls()
+            for idx in range(len(splits[0])):
+                key_cache = [
+                    current.key_cache[idx]
+                    for current in splits
+                    if current.key_cache[idx].numel()
+                ]
+                value_cache = [
+                    current.value_cache[idx]
+                    for current in splits
+                    if current.value_cache[idx].numel()
+                ]
+                if key_cache != []:
+                    layer_keys = torch.cat(key_cache, dim=0)
+                    layer_values = torch.cat(value_cache, dim=0)
+                    cache.update(layer_keys, layer_values, idx)
+            return cache
 
 
 class patched_GenerationMixin:
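The patched `DynamicCache` above (applied only for transformers < 4.51) keeps the standard cache semantics: `update` concatenates new key/value states along the sequence dimension (`dim=-2`) and `get_seq_length` reads that dimension back. A small usage sketch of those semantics against the public transformers API, assuming the usual `(batch, num_heads, seq_len, head_dim)` layout:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k0, v0 = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)  # 4 cached tokens
cache.update(k0, v0, layer_idx=0)
k1, v1 = torch.randn(1, 2, 1, 8), torch.randn(1, 2, 1, 8)  # one new token
keys, values = cache.update(k1, v1, layer_idx=0)
print(keys.shape)               # torch.Size([1, 2, 5, 8]), concatenated on dim=-2
print(cache.get_seq_length(0))  # 5
```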
@@ -1183,3 +1265,220 @@ class patched_IdeficsAttention(torch.nn.Module):
         if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
             return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
+
+
+class patched_SamMaskDecoder(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_positional_embeddings: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        output_attentions: Optional[bool] = None,
+        attention_similarity: Optional[torch.Tensor] = None,
+        target_embedding: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Args:
+            image_embeddings (`torch.Tensor`):
+                the embeddings from the image encoder
+            image_positional_embedding (`torch.Tensor`):
+                positional encoding with the shape of image_embeddings
+            sparse_prompt_embeddings (`torch.Tensor`):
+                The embeddings of the points and boxes
+            dense_prompt_embeddings (`torch.Tensor`):
+                the embeddings of the mask inputs
+            multimask_output (bool):
+                Whether to return multiple masks or a single mask.
+            output_attentions (bool, *optional*):
+                Whether or not to return the attentions tensors of all attention layers.
+        """
+        batch_size, num_channels, height, width = image_embeddings.shape
+        point_batch_size = sparse_prompt_embeddings.shape[1]
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
+        # torch.any is needed to avoid data-dependent control flow
+        # with sparse_prompt_embeddings.sum().item() != 0
+        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
+            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+
+        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
+            return output_tokens.clone()
+
+        tokens = torch.cond(
+            torch.any(sparse_prompt_embeddings != 0),
+            sparse_prompt_embeddings_is_not_empty,
+            sparse_prompt_embeddings_is_empty,
+            [output_tokens, sparse_prompt_embeddings],
+        )
+
+        point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+        # Expand per-image data in batch direction to be per-point
+        image_embeddings = image_embeddings + dense_prompt_embeddings
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
+            point_batch_size, 0
+        )
+
+        # Run the transformer, image_positional_embedding are consumed
+        torch._check(point_embeddings.shape[0] != 0)
+        torch._check(point_embeddings.shape[1] != 0)
+        torch._check(point_embeddings.shape[2] != 0)
+        torch._check(point_embeddings.shape[3] != 0)
+        embeddings_attentions = self.transformer(
+            point_embeddings=point_embeddings,
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+        point_embedding, image_embeddings = embeddings_attentions[:2]
+        iou_token_out = torch.select(point_embedding, dim=2, index=0)
+        mask_tokens_out = torch.narrow(
+            point_embedding, dim=2, start=1, length=self.num_mask_tokens
+        )
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        image_embeddings = image_embeddings.transpose(2, 3).reshape(
+            batch_size * point_batch_size, num_channels, height, width
+        )
+
+        upscaled_embedding = self.upscale_conv1(image_embeddings)
+        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
+        hyper_in_list = []
+        for i in range(self.num_mask_tokens):
+            current_mlp = self.output_hypernetworks_mlps[i]
+            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+        hyper_in = torch.stack(hyper_in_list, dim=2)
+
+        _, num_channels, height, width = upscaled_embedding.shape
+        upscaled_embedding = upscaled_embedding.reshape(
+            batch_size, point_batch_size, num_channels, height * width
+        )
+        masks = (hyper_in @ upscaled_embedding).reshape(
+            batch_size, point_batch_size, -1, height, width
+        )
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, :, mask_slice, :, :]
+        iou_pred = iou_pred[:, :, mask_slice]
+
+        outputs = (masks, iou_pred)
+
+        if len(embeddings_attentions) == 2:
+            # transformers==4.54
+            return outputs
+
+        if output_attentions and len(embeddings_attentions) > 2:
+            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
+        else:
+            outputs = outputs + (None,)  # noqa: RUF005
+        return outputs
+
+
+def rewrite_loop_for_square_mask(mask: torch.Tensor, seq: torch.Tensor):
+    """
+    Rewrites the loop in:
+
+    .. code-block:: python
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype
+        )
+        for i in range(1, len(seq)):
+            attention_mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
+    """
+    r = torch.arange(0, mask.shape[-1], dtype=torch.int64)
+    less0 = (r.reshape((-1, 1)) < seq.reshape((1, -1))).to(torch.int64)
+    less = less0.sum(axis=-1, keepdim=True) + 1
+    sq = less * less.T
+    look = (
+        torch.max(seq.min() == 0, less != less.max())
+        * torch.max(seq.max() == mask.shape[-1], less != less.min())
+        * less
+    )
+    filt = (sq != look**2).to(mask.dtype)
+    return mask * filt
+
+
+class patched_VisionAttention(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.qwen2_vl.modeling_qwen2_vl.VisionAttention
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        q, k, v = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        if position_embeddings is None:
+            transformers.models.qwen2_vl.modeling_qwen2_vl.logger.warning_once(
+                "The attention layers in this model are transitioning from "
+                " computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), "
+                "to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin)."
+                " In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = transformers.models.qwen2_vl.modeling_qwen2_vl.apply_rotary_pos_emb_vision(
+            q, k, cos, sin
+        )
+
+        attention_mask = torch.full(
+            [1, seq_length, seq_length],
+            torch.finfo(q.dtype).min,
+            device=q.device,
+            dtype=q.dtype,
+        )
+        # for i in range(1, len(cu_seqlens)):
+        #     attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i],
+        #     cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        attention_mask = rewrite_loop_for_square_mask(attention_mask, cu_seqlens)
+
+        q = q.transpose(0, 1)
+        k = k.transpose(0, 1)
+        v = v.transpose(0, 1)
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.nn.functional.softmax(
+            attn_weights, dim=-1, dtype=torch.float32
+        ).to(q.dtype)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(0, 1)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+        return attn_output
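`rewrite_loop_for_square_mask` replaces the data-dependent loop over `cu_seqlens` with vectorized tensor operations so `patched_VisionAttention.forward` can be exported. A quick self-contained check of the rewrite, assuming `cu_seqlens` starts at 0 and ends at the sequence length (as Qwen2-VL produces) and that the helper is importable from the module shown in this diff:

```python
import torch
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
    rewrite_loop_for_square_mask,
)


def loop_version(seq_length: int, seq: torch.Tensor) -> torch.Tensor:
    # The original per-block loop that the helper replaces.
    mask = torch.full([1, seq_length, seq_length], torch.finfo(torch.float32).min)
    for i in range(1, len(seq)):
        mask[..., seq[i - 1] : seq[i], seq[i - 1] : seq[i]] = 0
    return mask


seq_length = 7
seq = torch.tensor([0, 3, 7], dtype=torch.int64)  # two blocks: [0, 3) and [3, 7)
mask = torch.full([1, seq_length, seq_length], torch.finfo(torch.float32).min)
assert torch.equal(rewrite_loop_for_square_mask(mask, seq), loop_version(seq_length, seq))
```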