onnx-diagnostic 0.7.5-py3-none-any.whl → 0.7.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/ext_test_case.py +2 -0
- onnx_diagnostic/helpers/_log_helper.py +6 -6
- onnx_diagnostic/helpers/cache_helper.py +250 -15
- onnx_diagnostic/helpers/helper.py +146 -10
- onnx_diagnostic/helpers/mini_onnx_builder.py +7 -2
- onnx_diagnostic/helpers/onnx_helper.py +13 -7
- onnx_diagnostic/helpers/torch_helper.py +33 -11
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/image_text_to_text.py +260 -56
- onnx_diagnostic/tasks/mask_generation.py +139 -0
- onnx_diagnostic/tasks/text_generation.py +6 -2
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +17 -1
- onnx_diagnostic/torch_export_patches/patch_inputs.py +4 -1
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +309 -129
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +57 -40
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +288 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.6.dist-info}/RECORD +22 -21
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.5.dist-info → onnx_diagnostic-0.7.6.dist-info}/top_level.txt +0 -0

--- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -6,12 +6,17 @@ import torch
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
+    HybridCache,
     SlidingWindowCache,
     StaticCache,
 )

+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_

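The try/except import above is a compatibility shim: the primary import expects MambaCache under transformers.models.mamba.modeling_mamba and falls back to transformers.cache_utils for releases that still expose it there. A minimal sketch of the same fallback, usable in downstream code that must run against both layouts (the exact transformers version in which the class moved is not stated in this diff):

try:
    # newer transformers layout
    from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
    # older transformers layout
    from transformers.cache_utils import MambaCache

print(MambaCache.__module__)  # shows which location the installed transformers provides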
@@ -161,6 +166,9 @@ def serialization_functions(
         flatten_dynamic_cache,
         unflatten_dynamic_cache,
         flatten_with_keys_dynamic_cache,
+        flatten_hybrid_cache,
+        unflatten_hybrid_cache,
+        flatten_with_keys_hybrid_cache,
         flatten_mamba_cache,
         unflatten_mamba_cache,
         flatten_with_keys_mamba_cache,
@@ -187,6 +195,14 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
+        HybridCache: lambda verbose=verbose: register_class_serialization(
+            HybridCache,
+            flatten_hybrid_cache,
+            unflatten_hybrid_cache,
+            flatten_with_keys_hybrid_cache,
+            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            verbose=verbose,
+        ),
         MambaCache: lambda verbose=verbose: register_class_serialization(
             MambaCache,
             flatten_mamba_cache,
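These registrations give the exporter a way to flatten a HybridCache into plain tensors and rebuild it afterwards, mirroring what was already in place for DynamicCache and MambaCache. A hedged sketch of what such a flatten/unflatten pair typically looks like when registered with torch's (private) pytree utilities; ToyCache and the helper names below are invented for illustration, and register_class_serialization in onnx-diagnostic presumably wraps a call of this kind rather than being identical to it:

from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree


class ToyCache:
    """Stand-in for a cache class holding per-layer key/value tensors."""

    def __init__(self, key_cache: List[torch.Tensor], value_cache: List[torch.Tensor]):
        self.key_cache = key_cache
        self.value_cache = value_cache


def flatten_toy_cache(cache: ToyCache) -> Tuple[List[torch.Tensor], Any]:
    # children (tensors) first, static context second
    return [*cache.key_cache, *cache.value_cache], len(cache.key_cache)


def unflatten_toy_cache(children: List[torch.Tensor], n_layers: int) -> ToyCache:
    return ToyCache(list(children[:n_layers]), list(children[n_layers:]))


pytree.register_pytree_node(ToyCache, flatten_toy_cache, unflatten_toy_cache)

cache = ToyCache([torch.rand(2, 3)], [torch.rand(2, 3)])
flat, spec = pytree.tree_flatten(cache)
assert isinstance(pytree.tree_unflatten(flat, spec), ToyCache)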

--- a/onnx_diagnostic/torch_export_patches/patch_inputs.py
+++ b/onnx_diagnostic/torch_export_patches/patch_inputs.py
@@ -70,6 +70,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
     :param verbose: verbosity
     :return: (args, kwargs, dynamic shapes)
     """
+    from ..helpers.cache_helper import CacheKeyValue
+
     new_kwargs = {}
     if args:
         assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
@@ -121,7 +123,8 @@ def convert_dynamic_axes_into_dynamic_shapes(
             changes[k] = type(updated_kwargs[k])
             continue
         if isinstance(v, transformers.cache_utils.DynamicCache):
-            updated_kwargs[k] = [v.key_cache, v.value_cache]
+            ca = CacheKeyValue(v)
+            updated_kwargs[k] = [ca.key_cache, ca.value_cache]
             changes[k] = type(v)
             continue
         raise NotImplementedError(
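The change above stops reading key_cache and value_cache directly off the DynamicCache and goes through the new CacheKeyValue helper instead, which keeps the conversion working across transformers versions whose cache internals differ. A rough sketch of the kind of adapter this implies; the attribute names checked below (layers, keys, values) are assumptions for illustration only, not the actual onnx-diagnostic implementation:

class KeyValueView:
    """Uniform view over a cache object, whatever its internal layout."""

    def __init__(self, cache):
        if hasattr(cache, "key_cache") and hasattr(cache, "value_cache"):
            # older layout: two parallel lists of tensors
            self.key_cache = list(cache.key_cache)
            self.value_cache = list(cache.value_cache)
        elif hasattr(cache, "layers"):
            # assumed newer layout: one layer object holding keys/values
            self.key_cache = [layer.keys for layer in cache.layers]
            self.value_cache = [layer.values for layer in cache.layers]
        else:
            raise NotImplementedError(f"Unsupported cache type {type(cache)}")


class _LegacyCache:
    """Duck-typed stand-in for an old-style cache, just to exercise the view."""

    def __init__(self):
        self.key_cache, self.value_cache = [], []


view = KeyValueView(_LegacyCache())
print(view.key_cache, view.value_cache)  # [] []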

--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -1,12 +1,19 @@
 import inspect
 from dataclasses import dataclass
 from functools import wraps
-from typing import
+from typing import Callable, List, Optional, Tuple
 import packaging.version as pv
 import torch
 import transformers
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.cache_utils import StaticCache, Cache
+from transformers.cache_utils import StaticCache, Cache
+
+try:
+    from transformers.cache_utils import parse_processor_args  # noqa: F401
+
+    patch_parse_processor_args = True
+except ImportError:
+    patch_parse_processor_args = False

 try:
     import transformers.masking_utils
@@ -15,10 +22,10 @@ try:
 except ImportError:
     patch_masking_utils = False

+
 from ...ext_test_case import has_transformers
 from ...helpers.torch_helper import is_torchdynamo_exporting

-
 if patch_masking_utils:
     # Introduced in 4.52
     from transformers.masking_utils import causal_mask_function, sdpa_mask
@@ -110,6 +117,46 @@ if patch_masking_utils:
         return mask


+if patch_parse_processor_args:
+
+    def _init_cache_inspect():
+        res = {}
+        for processor_class in transformers.cache_utils.PROCESSOR_CLASS_MAP.values():
+            try:
+                params = list(inspect.signature(processor_class.__init__).parameters)[2:]
+                res[processor_class.__init__] = params
+            except Exception:
+                res[processor_class.__init__] = None
+        return res
+
+    _cache_inspect = _init_cache_inspect()
+
+    def patched_parse_processor_args(
+        processor_class: Optional[type["CacheProcessor"]], kwargs: dict  # noqa: F821
+    ) -> tuple[dict, dict]:
+        """[patch:transformers.cache_utils.parse_processor_args]"""
+        # If not patched...
+        # Fails with transformers>=4.54 because function ``parse_processor_args``
+        # relies in inspect and the exporter is not very fond of that.
+        # torch._dynamo.exc.Unsupported: id() with unsupported args
+        # Explanation: Dynamo doesn't know how to trace id()
+        # call with args
+        # (GetAttrVariable(ConstantVariable(NoneType: None), __init__),)
+        # Hint: Supported args are Tensors, and functions/nn.Modules/user-defined
+        # objects from outside the compiled region.
+        # Hint: It may be possible to write Dynamo tracing rules for this code.
+        #
+        # The patch is caching the signature to avoid any call to inspect.
+        if processor_class is None:
+            return {}, kwargs
+        params = _cache_inspect[processor_class.__init__]
+        if params is None:
+            return {}, kwargs
+        processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
+        remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
+        return processor_kwargs, remaining_kwargs
+
+
 def _patch_make_causal_mask(
     input_ids_shape: torch.Size,
     dtype: torch.dtype,
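The patched parse_processor_args above replaces per-call inspect.signature lookups with a table computed once at import time, because Dynamo cannot trace through the inspect module. A condensed sketch of the same idea on a made-up processor class; all names below are illustrative and not transformers API:

import inspect
from typing import Dict, List, Optional


class _ToyProcessor:  # stand-in for a cache-processor class
    def __init__(self, cache, window: int = 8, dtype: str = "float16"):
        self.window, self.dtype = window, dtype


# built once, outside any traced/exported region
_PARAMS: Dict[type, Optional[List[str]]] = {
    _ToyProcessor: list(inspect.signature(_ToyProcessor.__init__).parameters)[2:]
}


def split_kwargs(processor_class, kwargs):
    """Split kwargs into (processor kwargs, remaining kwargs) using the cached table."""
    params = _PARAMS.get(processor_class)
    if not params:
        return {}, kwargs
    picked = {k: kwargs[k] for k in params if k in kwargs}
    rest = {k: v for k, v in kwargs.items() if k not in picked}
    return picked, rest


print(split_kwargs(_ToyProcessor, {"window": 4, "max_batch_size": 2}))
# -> ({'window': 4}, {'max_batch_size': 2})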
@@ -192,136 +239,140 @@ class patched_AttentionMaskConverter:
         return _patch_make_causal_mask(**kwargs)


-  [… 3 removed lines not captured in this view …]
-    `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
-    """
-
-    _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
-    _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states.
-        A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        is_empty_layer = (
-            len(self.key_cache) == 0  # no cache in any layer
-            or len(self.key_cache)
-            <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
-            or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
-        )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
-        return layer_seq_length
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx].numel():
-                device = self.key_cache[layer_idx].device
-                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
-                    0, beam_idx.to(device)
-                )
-            if self.value_cache[layer_idx].numel():
-                device = self.value_cache[layer_idx].device
-                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
-                    0, beam_idx.to(device)
-                )
+if pv.Version(transformers.__version__) < pv.Version("4.51"):
+    from typing import Any, Dict
+    from transformers.cache_utils import DynamicCache

-  [… 1 removed line not captured in this view …]
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    class patched_DynamicCache:
         """
-  [… 3 removed lines not captured in this view …]
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass.
-                No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
+        Applies modifications implemented in PR
+        `transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
         """
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            if hasattr(self, "_seen_tokens"):
-                self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
-                    self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif not self.key_cache[
-                layer_idx
-            ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
-                # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat(
-                    [self.key_cache[layer_idx], key_states], dim=-2
-                )
-                self.value_cache[layer_idx] = torch.cat(
-                    [self.value_cache[layer_idx], value_states], dim=-2
-                )
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]

-  [… 40 removed lines (the rest of the old module-level patched_DynamicCache) not captured in this view …]
+        _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
+        _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
+
+        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+            """Returns the sequence length of the cached states.
+            A layer index can be optionally passed."""
+            # TODO: deprecate this function in favor of `cache_position`
+            is_empty_layer = (
+                len(self.key_cache) == 0  # no cache in any layer
+                or len(self.key_cache)
+                <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+                or self.key_cache[layer_idx].numel() == 0  # the layer has no cache
+            )
+            layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+            return layer_seq_length
+
+        def reorder_cache(self, beam_idx: torch.LongTensor):
+            """Reorders the cache for beam search, given the selected beam indices."""
+            for layer_idx in range(len(self.key_cache)):
+                if self.key_cache[layer_idx].numel():
+                    device = self.key_cache[layer_idx].device
+                    self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+                if self.value_cache[layer_idx].numel():
+                    device = self.value_cache[layer_idx].device
+                    self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(
+                        0, beam_idx.to(device)
+                    )
+
+        def update(
+            self,
+            key_states: torch.Tensor,
+            value_states: torch.Tensor,
+            layer_idx: int,
+            cache_kwargs: Optional[Dict[str, Any]] = None,
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Updates the cache with the new `key_states`
+            and `value_states` for the layer `layer_idx`.
+            Parameters:
+                key_states (`torch.Tensor`):
+                    The new key states to cache.
+                value_states (`torch.Tensor`):
+                    The new value states to cache.
+                layer_idx (`int`):
+                    The index of the layer to cache the states for.
+                cache_kwargs (`Dict[str, Any]`, `optional`):
+                    Additional arguments for the cache subclass.
+                    No additional arguments are used in `DynamicCache`.
+            Return:
+                A tuple containing the updated key and value states.
+            """
+            # Update the number of seen tokens
+            if layer_idx == 0:
+                if hasattr(self, "_seen_tokens"):
+                    self._seen_tokens += key_states.shape[-2]
+
+            # Update the cache
+            if key_states is not None:
+                if len(self.key_cache) <= layer_idx:
+                    # There may be skipped layers, fill them with empty lists
+                    for _ in range(len(self.key_cache), layer_idx):
+                        self.key_cache.append(torch.tensor([], dtype=key_states.dtype))
+                        self.value_cache.append(torch.tensor([], dtype=key_states.dtype))
+                    self.key_cache.append(key_states)
+                    self.value_cache.append(value_states)
+                elif not self.key_cache[
+                    layer_idx
+                ].numel():  # prefers not t.numel() to len(t) == 0 to export the model
+                    # fills previously skipped layers; checking for tensor causes errors
+                    self.key_cache[layer_idx] = key_states
+                    self.value_cache[layer_idx] = value_states
+                else:
+                    self.key_cache[layer_idx] = torch.cat(
+                        [self.key_cache[layer_idx], key_states], dim=-2
+                    )
+                    self.value_cache[layer_idx] = torch.cat(
+                        [self.value_cache[layer_idx], value_states], dim=-2
+                    )
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        def crop(self, max_length: int):
+            """Crop the past key values up to a new `max_length`
+            in terms of tokens. `max_length` can also be
+            negative to remove `max_length` tokens.
+            This is used in assisted decoding and contrastive search.
+            """
+            # In case it is negative
+            if max_length < 0:
+                max_length = self.get_seq_length() - abs(max_length)
+
+            if self.get_seq_length() <= max_length:
+                return
+
+            if hasattr(self, "_seen_tokens"):
+                self._seen_tokens = max_length
+            for idx in range(len(self.key_cache)):
+                if self.key_cache[idx].numel():
+                    self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                    self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+
+        @classmethod
+        def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache:
+            """This is the opposite of the above `batch_split()` method.
+            This will be used by `stack_model_outputs` in
+            `generation.utils`"""
+            cache = cls()
+            for idx in range(len(splits[0])):
+                key_cache = [
+                    current.key_cache[idx]
+                    for current in splits
+                    if current.key_cache[idx].numel()
+                ]
+                value_cache = [
+                    current.value_cache[idx]
+                    for current in splits
+                    if current.value_cache[idx].numel()
+                ]
+                if key_cache != []:
+                    layer_keys = torch.cat(key_cache, dim=0)
+                    layer_values = torch.cat(value_cache, dim=0)
+                    cache.update(layer_keys, layer_values, idx)
+            return cache


 class patched_GenerationMixin:
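One detail worth noting in the class above: empty layers are represented by empty tensors and tested with numel() rather than len(...) == 0, which the inline comments say exports better. A small stand-alone illustration of that update pattern; pick is an invented helper for this sketch, not part of the package:

import torch


def pick(cache_entry: torch.Tensor, new_values: torch.Tensor) -> torch.Tensor:
    # pattern used by the patched update(): an empty placeholder tensor means
    # "no cache yet" and is replaced, otherwise new values are concatenated
    # on the sequence axis
    if not cache_entry.numel():
        return new_values
    return torch.cat([cache_entry, new_values], dim=-2)


empty = torch.tensor([])
step1 = pick(empty, torch.rand(1, 2, 4, 8))
step2 = pick(step1, torch.rand(1, 2, 1, 8))
print(step2.shape)  # torch.Size([1, 2, 5, 8])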
@@ -1183,3 +1234,132 @@ class patched_IdeficsAttention(torch.nn.Module):
         if pv.Version(transformers.__version__) < pv.Version("4.53.99"):
             return attn_output, attn_weights, past_key_value
         return attn_output, attn_weights
+
+
+class patched_SamMaskDecoder(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_positional_embeddings: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        output_attentions: Optional[bool] = None,
+        attention_similarity: Optional[torch.Tensor] = None,
+        target_embedding: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Args:
+            image_embeddings (`torch.Tensor`):
+                the embeddings from the image encoder
+            image_positional_embedding (`torch.Tensor`):
+                positional encoding with the shape of image_embeddings
+            sparse_prompt_embeddings (`torch.Tensor`):
+                The embeddings of the points and boxes
+            dense_prompt_embeddings (`torch.Tensor`):
+                the embeddings of the mask inputs
+            multimask_output (bool):
+                Whether to return multiple masks or a single mask.
+            output_attentions (bool, *optional*):
+                Whether or not to return the attentions tensors of all attention layers.
+        """
+        batch_size, num_channels, height, width = image_embeddings.shape
+        point_batch_size = sparse_prompt_embeddings.shape[1]
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
+
+        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
+        # torch.any is needed to avoid data-dependent control flow
+        # with sparse_prompt_embeddings.sum().item() != 0
+        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
+            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
+
+        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
+            return output_tokens.clone()
+
+        tokens = torch.cond(
+            torch.any(sparse_prompt_embeddings != 0),
+            sparse_prompt_embeddings_is_not_empty,
+            sparse_prompt_embeddings_is_empty,
+            [output_tokens, sparse_prompt_embeddings],
+        )
+
+        point_embeddings = tokens.to(self.iou_token.weight.dtype)
+
+        # Expand per-image data in batch direction to be per-point
+        image_embeddings = image_embeddings + dense_prompt_embeddings
+        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
+        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
+            point_batch_size, 0
+        )
+
+        # Run the transformer, image_positional_embedding are consumed
+        torch._check(point_embeddings.shape[0] != 0)
+        torch._check(point_embeddings.shape[1] != 0)
+        torch._check(point_embeddings.shape[2] != 0)
+        torch._check(point_embeddings.shape[3] != 0)
+        embeddings_attentions = self.transformer(
+            point_embeddings=point_embeddings,
+            image_embeddings=image_embeddings,
+            image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
+            output_attentions=output_attentions,
+        )
+        point_embedding, image_embeddings = embeddings_attentions[:2]
+        iou_token_out = torch.select(point_embedding, dim=2, index=0)
+        mask_tokens_out = torch.narrow(
+            point_embedding, dim=2, start=1, length=self.num_mask_tokens
+        )
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        image_embeddings = image_embeddings.transpose(2, 3).reshape(
+            batch_size * point_batch_size, num_channels, height, width
+        )
+
+        upscaled_embedding = self.upscale_conv1(image_embeddings)
+        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
+        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
+
+        hyper_in_list = []
+        for i in range(self.num_mask_tokens):
+            current_mlp = self.output_hypernetworks_mlps[i]
+            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
+        hyper_in = torch.stack(hyper_in_list, dim=2)
+
+        _, num_channels, height, width = upscaled_embedding.shape
+        upscaled_embedding = upscaled_embedding.reshape(
+            batch_size, point_batch_size, num_channels, height * width
+        )
+        masks = (hyper_in @ upscaled_embedding).reshape(
+            batch_size, point_batch_size, -1, height, width
+        )
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, :, mask_slice, :, :]
+        iou_pred = iou_pred[:, :, mask_slice]
+
+        outputs = (masks, iou_pred)
+
+        if len(embeddings_attentions) == 2:
+            # transformers==4.54
+            return outputs
+
+        if output_attentions and len(embeddings_attentions) > 2:
+            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
+        else:
+            outputs = outputs + (None,)  # noqa: RUF005
+        return outputs