optimum-rbln 0.8.3a4__py3-none-any.whl → 0.8.4a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (31)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +15 -0
  4. optimum/rbln/modeling.py +2 -4
  5. optimum/rbln/modeling_base.py +44 -13
  6. optimum/rbln/transformers/__init__.py +14 -0
  7. optimum/rbln/transformers/configuration_generic.py +2 -0
  8. optimum/rbln/transformers/modeling_generic.py +12 -4
  9. optimum/rbln/transformers/models/__init__.py +18 -0
  10. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  11. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  12. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  13. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  14. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +6 -1
  15. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +6 -3
  16. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +7 -1
  17. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +12 -31
  18. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +1 -1
  19. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +7 -1
  20. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  21. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  22. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  23. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  24. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +2 -0
  25. optimum/rbln/transformers/models/swin/modeling_swin.py +32 -7
  26. optimum/rbln/transformers/utils/rbln_quantization.py +47 -31
  27. optimum/rbln/utils/submodule.py +10 -4
  28. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/METADATA +1 -1
  29. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/RECORD +31 -26
  30. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/WHEEL +0 -0
  31. {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.4a0.dist-info}/licenses/LICENSE +0 -0
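
The headline change in this release is Grounding DINO support: three new files under optimum/rbln/transformers/models/grounding_dino/ plus registrations in the auto classes (modeling_auto.py) and the package __init__ files. A minimal usage sketch follows; it assumes the new model is exposed as RBLNGroundingDinoForObjectDetection and follows optimum-rbln's usual from_pretrained(..., export=True) compile-then-run flow, neither of which is shown verbatim in this diff.

import requests
import torch
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNGroundingDinoForObjectDetection  # exported name assumed, not confirmed by this diff

model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)

# Compile the Hugging Face checkpoint for RBLN NPUs (flow assumed to match other optimum-rbln models).
model = RBLNGroundingDinoForObjectDetection.from_pretrained(model_id, export=True)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(images=image, text="a cat. a remote control.", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
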
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py (new file)
@@ -0,0 +1,86 @@
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, List, Optional, Tuple, Union
+
+ import torch
+
+ from ...configuration_generic import RBLNImageModelConfig, RBLNModelConfig
+
+
+ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
+     submodules = [
+         "text_backbone",
+         "backbone",
+         "encoder",
+         "decoder",
+     ]
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         encoder: Optional["RBLNGroundingDinoEncoderConfig"] = None,
+         decoder: Optional["RBLNGroundingDinoDecoderConfig"] = None,
+         text_backbone: Optional["RBLNModelConfig"] = None,
+         backbone: Optional["RBLNModelConfig"] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.encoder = encoder
+         self.decoder = decoder
+         self.text_backbone = text_backbone
+         self.backbone = backbone
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+ class RBLNGroundingDinoComponentConfig(RBLNImageModelConfig):
+     def __init__(
+         self,
+         image_size: Optional[Union[int, Tuple[int, int]]] = None,
+         batch_size: Optional[int] = None,
+         spatial_shapes_list: Optional[List[Tuple[int, int]]] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         **kwargs: Any,
+     ):
+         super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
+         self.spatial_shapes_list = spatial_shapes_list
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+
+     @property
+     def spatial_shapes(self):
+         if self.spatial_shapes_list is None:
+             raise ValueError("Spatial shapes are not defined. Please set them before accessing.")
+         return torch.tensor(self.spatial_shapes_list)
+
+
+ class RBLNGroundingDinoEncoderConfig(RBLNGroundingDinoComponentConfig):
+     pass
+
+
+ class RBLNGroundingDinoDecoderConfig(RBLNGroundingDinoComponentConfig):
+     pass
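
For reference, a short sketch of how the component configs above are meant to be used; the spatial_shapes_list values below are made-up placeholders (real values depend on the vision backbone's feature maps and the chosen image size):

from optimum.rbln.transformers.models.grounding_dino.configuration_grounding_dino import (
    RBLNGroundingDinoEncoderConfig,
)

# Placeholder multi-scale feature-map shapes; actual values come from the backbone.
encoder_cfg = RBLNGroundingDinoEncoderConfig(
    image_size=(800, 800),
    batch_size=1,
    spatial_shapes_list=[(100, 100), (50, 50), (25, 25), (13, 13)],
)

# The spatial_shapes property materializes the list as a torch.Tensor and raises if it was never set.
print(encoder_cfg.spatial_shapes.shape)  # torch.Size([4, 2])
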
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py (new file)
@@ -0,0 +1,507 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from functools import wraps
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from transformers.models.grounding_dino.modeling_grounding_dino import (
+     GroundingDinoDecoder,
+     GroundingDinoEncoder,
+     get_sine_pos_embed,
+ )
+
+
+ if TYPE_CHECKING:
+     from .configuration_grounding_dino import RBLNGroundingDinoDecoderConfig, RBLNGroundingDinoEncoderConfig
+
+
+ def monkey_patch():
+     from transformers.models.grounding_dino.modeling_grounding_dino import (
+         GroundingDinoBiMultiHeadAttention,
+         GroundingDinoEncoderLayer,
+         GroundingDinoMultiscaleDeformableAttention,
+     )
+
+     original_forward = GroundingDinoMultiscaleDeformableAttention.forward
+     original_bi_multihead_attention_forward = GroundingDinoBiMultiHeadAttention.forward
+     original_encoder_layer_forward = GroundingDinoEncoderLayer.forward
+
+     # Patch the methods with the custom implementations
+     GroundingDinoMultiscaleDeformableAttention.forward = _GroundingDinoMultiscaleDeformableAttention.forward
+     GroundingDinoBiMultiHeadAttention.forward = _GroundingDinoBiMultiHeadAttention.forward
+     GroundingDinoEncoderLayer.forward = _GroundingDinoEncoderLayer.forward
+
+     return (original_forward, original_bi_multihead_attention_forward, original_encoder_layer_forward)
+
+
+ def restore_monkey_patch(original_forward, original_bi_multihead_attention_forward, original_encoder_layer_forward):
+     from transformers.models.grounding_dino.modeling_grounding_dino import (
+         GroundingDinoBiMultiHeadAttention,
+         GroundingDinoEncoderLayer,
+         GroundingDinoMultiscaleDeformableAttention,
+     )
+
+     # Restore the original methods
+     GroundingDinoMultiscaleDeformableAttention.forward = original_forward
+     GroundingDinoBiMultiHeadAttention.forward = original_bi_multihead_attention_forward
+     GroundingDinoEncoderLayer.forward = original_encoder_layer_forward
+
+
+ def monkey_patch_decorator(func):
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         # Apply monkey patch and capture original methods
+         original_functions = monkey_patch()
+         try:
+             # Call the original function
+             result = func(*args, **kwargs)
+         finally:
+             # Restore original methods
+             restore_monkey_patch(*original_functions)
+         return result
+
+     return wrapper
+
+
+ class _GroundingDinoEncoder(torch.nn.Module):
+     def __init__(self, model: "GroundingDinoEncoder", rbln_config: "RBLNGroundingDinoEncoderConfig"):
+         super().__init__()
+         self.layers = model.layers
+         self.config = model.config
+         self.rbln_config = rbln_config
+         self.spatial_shapes = self.rbln_config.spatial_shapes
+         self.spatial_shapes_list = self.rbln_config.spatial_shapes_list
+         self.text_position_embedding = model.layers[0].get_text_position_embeddings(
+             torch.zeros(1, model.config.max_text_len, model.config.d_model),
+             None,
+             torch.arange(model.config.max_text_len, dtype=torch.int32).unsqueeze(0),
+         )
+
+     @monkey_patch_decorator
+     def forward(
+         self,
+         vision_features: torch.Tensor,
+         vision_attention_mask: torch.Tensor,
+         vision_position_embedding: torch.Tensor,
+         text_features: Optional[torch.Tensor] = None,
+         text_attention_mask: Optional[torch.Tensor] = None,
+         text_self_attention_masks: Optional[torch.Tensor] = None,
+         reference_points: Optional[torch.Tensor] = None,
+     ):
+         output_attentions = self.rbln_config.output_attentions
+         output_hidden_states = self.rbln_config.output_hidden_states
+
+         encoder_vision_states = () if output_hidden_states else None
+         encoder_text_states = () if output_hidden_states else None
+         all_attns = () if output_attentions else None
+         all_attn_fused_text = () if output_attentions else None
+         all_attn_fused_vision = () if output_attentions else None
+         all_attn_enhanced_text = () if output_attentions else None
+         all_attn_deformable = () if output_attentions else None
+         for i, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_vision_states += (vision_features,)
+                 encoder_text_states += (text_features,)
+
+             (vision_features, text_features), attentions = encoder_layer(
+                 vision_features=vision_features,
+                 vision_position_embedding=vision_position_embedding,
+                 spatial_shapes=self.spatial_shapes,
+                 spatial_shapes_list=self.spatial_shapes_list,
+                 level_start_index=None,
+                 key_padding_mask=vision_attention_mask,
+                 reference_points=reference_points,
+                 text_features=text_features,
+                 text_attention_mask=text_attention_mask,
+                 text_position_embedding=self.text_position_embedding,
+                 text_self_attention_masks=text_self_attention_masks,
+             )
+             if output_attentions:
+                 all_attn_fused_vision += (attentions[0],)
+                 all_attn_fused_text += (attentions[1],)
+                 all_attn_enhanced_text += (attentions[2],)
+                 all_attn_deformable += (attentions[3],)
+
+         if output_hidden_states:
+             encoder_vision_states += (vision_features,)
+             encoder_text_states += (text_features,)
+
+         if output_attentions:
+             all_attns = (all_attn_fused_vision, all_attn_fused_text, all_attn_enhanced_text, all_attn_deformable)
+
+         enc_outputs = [vision_features, text_features, encoder_vision_states, encoder_text_states, all_attns]
+
+         return tuple(v for v in enc_outputs if v is not None)
+
+
+ class _GroundingDinoDecoder(torch.nn.Module):
+     def __init__(self, model: "GroundingDinoDecoder", rbln_config: "RBLNGroundingDinoDecoderConfig"):
+         super().__init__()
+         self.layers = model.layers
+         self.config = model.config
+         self.spatial_shapes = rbln_config.spatial_shapes
+         self.spatial_shapes_list = rbln_config.spatial_shapes_list
+         self.rbln_config = rbln_config
+         self.reference_points_head = model.reference_points_head
+         self.bbox_embed = model.bbox_embed
+         self.layer_norm = model.layer_norm
+
+     @monkey_patch_decorator
+     def forward(
+         self,
+         inputs_embeds,
+         vision_encoder_hidden_states,
+         vision_encoder_attention_mask=None,
+         text_encoder_hidden_states=None,
+         text_encoder_attention_mask=None,
+         reference_points=None,
+         valid_ratios=None,
+     ):
+         output_attentions = self.rbln_config.output_attentions
+         output_hidden_states = self.rbln_config.output_hidden_states
+
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         all_attns = () if output_attentions else None
+         all_cross_attns_vision = () if (output_attentions and vision_encoder_hidden_states is not None) else None
+         all_cross_attns_text = () if (output_attentions and text_encoder_hidden_states is not None) else None
+         intermediate = ()
+         intermediate_reference_points = ()
+
+         if text_encoder_attention_mask is not None:
+             text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :]
+             text_encoder_attention_mask = text_encoder_attention_mask.repeat(
+                 1, self.config.decoder_attention_heads, self.config.num_queries, 1
+             )
+             text_encoder_attention_mask = text_encoder_attention_mask
+             text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(torch.float16).min
+
+         for idx, decoder_layer in enumerate(self.layers):
+             num_coordinates = reference_points.shape[-1]
+             if num_coordinates == 4:
+                 reference_points_input = (
+                     reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+                 )
+             elif num_coordinates == 2:
+                 reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
+             else:
+                 raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+             _query_pos = get_sine_pos_embed(reference_points_input[:, :, 0, :], num_pos_feats=self.config.d_model // 2)
+             query_pos = self.reference_points_head(_query_pos)
+
+             # In the original implementation they apply layer norm before outputting intermediate hidden states.
+             # That is not done between layers, though, so each layer uses as input the output of the previous layer
+             # without layer norm.
+             if output_hidden_states:
+                 all_hidden_states += (self.layer_norm(hidden_states),)
+
+             layer_outputs = decoder_layer(
+                 hidden_states=hidden_states,
+                 position_embeddings=query_pos,
+                 reference_points=reference_points_input,
+                 spatial_shapes=self.spatial_shapes,
+                 spatial_shapes_list=self.spatial_shapes_list,
+                 level_start_index=None,
+                 vision_encoder_hidden_states=vision_encoder_hidden_states,
+                 vision_encoder_attention_mask=vision_encoder_attention_mask,
+                 text_encoder_hidden_states=text_encoder_hidden_states,
+                 text_encoder_attention_mask=text_encoder_attention_mask,
+                 self_attn_mask=None,
+                 output_attentions=output_attentions,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             # hack implementation for iterative bounding box refinement
+             if self.bbox_embed is not None:
+                 tmp = self.bbox_embed[idx](hidden_states)
+                 num_coordinates = reference_points.shape[-1]
+                 if num_coordinates == 4:
+                     new_reference_points = tmp + torch.special.logit(reference_points, eps=1e-5)
+                     new_reference_points = new_reference_points.sigmoid()
+                 elif num_coordinates == 2:
+                     new_reference_points = tmp
+                     new_reference_points[..., :2] = tmp[..., :2] + torch.special.logit(reference_points, eps=1e-5)
+                     new_reference_points = new_reference_points.sigmoid()
+                 else:
+                     raise ValueError(
+                         f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}"
+                     )
+                 reference_points = new_reference_points.detach()
+
+             intermediate += (self.layer_norm(hidden_states),)
+             intermediate_reference_points += (reference_points,)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+                 if text_encoder_hidden_states is not None:
+                     all_cross_attns_text += (layer_outputs[2],)
+
+                 if vision_encoder_hidden_states is not None:
+                     all_cross_attns_vision += (layer_outputs[3],)
+
+         # Keep batch_size as first dimension
+         intermediate = torch.stack(intermediate, dim=1)
+         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
+         hidden_states = self.layer_norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         if output_attentions:
+             all_attns += (all_self_attns, all_cross_attns_text, all_cross_attns_vision)
+
+         return tuple(
+             v
+             for v in [
+                 hidden_states,
+                 intermediate,
+                 intermediate_reference_points,
+                 all_hidden_states,
+                 all_attns,
+             ]
+             if v is not None
+         )
+
+
+ class _GroundingDinoEncoderLayer(torch.nn.Module):
+     def forward(
+         self,
+         vision_features: Tensor,
+         vision_position_embedding: Tensor,
+         spatial_shapes: Tensor,
+         spatial_shapes_list: List[Tuple[int, int]],
+         level_start_index: Tensor,
+         key_padding_mask: Tensor,
+         reference_points: Tensor,
+         text_features: Optional[Tensor] = None,
+         text_attention_mask: Optional[Tensor] = None,
+         text_position_embedding: Optional[Tensor] = None,
+         text_self_attention_masks: Optional[Tensor] = None,
+         text_position_ids: Optional[Tensor] = None,
+     ):
+         text_position_embedding = self.get_text_position_embeddings(
+             text_features, text_position_embedding, text_position_ids
+         )
+
+         (vision_features, vision_fused_attn), (text_features, text_fused_attn) = self.fusion_layer(
+             vision_features=vision_features,
+             text_features=text_features,
+             attention_mask_vision=key_padding_mask,
+             attention_mask_text=text_attention_mask,
+         )
+
+         (text_features, text_enhanced_attn) = self.text_enhancer_layer(
+             hidden_states=text_features,
+             attention_masks=(1.0 - text_self_attention_masks),  # RBLN FIX, change from ~ to 1.0 -
+             position_embeddings=(text_position_embedding if text_position_embedding is not None else None),
+         )
+
+         (vision_features, vision_deformable_attn) = self.deformable_layer(
+             hidden_states=vision_features,
+             attention_mask=(1.0 - key_padding_mask),  # RBLN FIX, change from ~ to 1.0 -
+             position_embeddings=vision_position_embedding,
+             reference_points=reference_points,
+             spatial_shapes=spatial_shapes,
+             spatial_shapes_list=spatial_shapes_list,
+             level_start_index=level_start_index,
+         )
+
+         return (
+             (vision_features, text_features),
+             (vision_fused_attn, text_fused_attn, text_enhanced_attn, vision_deformable_attn),
+         )
+
+
+ class _GroundingDinoMultiscaleDeformableAttention(torch.nn.Module):
+     """
+     Multiscale deformable attention as proposed in Deformable DETR.
+     """
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         position_embeddings: Optional[torch.Tensor] = None,
+         reference_points=None,
+         spatial_shapes=None,
+         spatial_shapes_list=None,
+         level_start_index=None,
+         output_attentions: bool = False,
+     ):
+         # add position embeddings to the hidden states before projecting to queries and keys
+         if position_embeddings is not None:
+             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+         batch_size, num_queries, _ = hidden_states.shape
+         batch_size, sequence_length, _ = encoder_hidden_states.shape
+         # Ignore copy
+         if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+             raise ValueError(
+                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+             )
+
+         value = self.value_proj(encoder_hidden_states)
+         if attention_mask is not None:
+             # RBLN FIX: bool tensor to float tensor
+             value = attention_mask * value
+
+         value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+         sampling_offsets = self.sampling_offsets(hidden_states).view(
+             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+         )
+         attention_weights = self.attention_weights(hidden_states).view(
+             batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+         )
+         attention_weights = F.softmax(attention_weights, -1).view(
+             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+         )
+         # batch_size, num_queries, n_heads, n_levels, n_points, 2
+         num_coordinates = reference_points.shape[-1]
+         if num_coordinates == 2:
+             offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+             sampling_locations = (
+                 reference_points[:, :, None, :, None, :]
+                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+             )
+         elif num_coordinates == 4:
+             sampling_locations = (
+                 reference_points[:, :, None, :, None, :2]
+                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+             )
+         else:
+             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+
+         output = self.attn(
+             value,
+             spatial_shapes,
+             spatial_shapes_list,
+             level_start_index,
+             sampling_locations,
+             attention_weights,
+             self.im2col_step,
+         )
+
+         output = self.output_proj(output)
+
+         return output, attention_weights
+
+
+ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
+     def forward(
+         self,
+         vision_features: torch.FloatTensor,
+         text_features: torch.FloatTensor,
+         vision_attention_mask: Optional[torch.BoolTensor] = None,
+         text_attention_mask: Optional[torch.BoolTensor] = None,
+     ) -> Tuple[Tuple[torch.FloatTensor, torch.FloatTensor], Tuple[torch.FloatTensor, torch.FloatTensor]]:
+         batch_size, tgt_len, _ = vision_features.size()
+
+         vision_query_states = self.vision_proj(vision_features) * self.scale
+         vision_query_states = self._reshape(vision_query_states, tgt_len, batch_size)
+
+         text_key_states = self.text_proj(text_features)
+         text_key_states = self._reshape(text_key_states, -1, batch_size)
+
+         vision_value_states = self.values_vision_proj(vision_features)
+         vision_value_states = self._reshape(vision_value_states, -1, batch_size)
+
+         text_value_states = self.values_text_proj(text_features)
+         text_value_states = self._reshape(text_value_states, -1, batch_size)
+
+         proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+
+         vision_query_states = vision_query_states.view(*proj_shape)
+         text_key_states = text_key_states.view(*proj_shape)
+         vision_value_states = vision_value_states.view(*proj_shape)
+         text_value_states = text_value_states.view(*proj_shape)
+
+         src_len = text_key_states.size(1)
+         attn_weights = torch.bmm(vision_query_states, text_key_states.transpose(1, 2))  # bs*nhead, nimg, ntxt
+
+         if attn_weights.size() != (batch_size * self.num_heads, tgt_len, src_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(batch_size * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+             )
+
+         # RBLN FIX: max_values from scalar to vector
+         attn_weights = attn_weights - torch.max(attn_weights).reshape(1).repeat(src_len)
+         # Do not increase -50000/50000, data type half has quite limited range
+         attn_weights = torch.clamp(attn_weights, min=-50000, max=50000)
+
+         attn_weights_transposed = attn_weights.transpose(1, 2)
+         # RBLN FIX: max_values from scalar to vector
+         text_attn_weights = attn_weights_transposed - torch.max(attn_weights_transposed, dim=-1, keepdim=True)[
+             0
+         ].repeat(1, 1, tgt_len)
+
+         # Do not increase -50000/50000, data type half has quite limited range
+         text_attn_weights = torch.clamp(text_attn_weights, min=-50000, max=50000)
+
+         # mask vision for language
+         if vision_attention_mask is not None:
+             # RBLN FIX: bool tensor to float tensor
+             mask = vision_attention_mask * torch.finfo(torch.float16).min
+             text_attn_weights = text_attn_weights.transpose(1, 2) + mask
+             text_attn_weights = text_attn_weights.transpose(1, 2)
+
+         text_attn_weights = text_attn_weights.softmax(dim=-1)
+
+         # mask language for vision
+         if text_attention_mask is not None:
+             text_attention_mask = text_attention_mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+             # RBLN FIX: bool tensor to float tensor
+             mask = text_attention_mask * torch.finfo(torch.float16).min
+             attn_weights = attn_weights + mask
+
+         vision_attn_weights = attn_weights.softmax(dim=-1)
+
+         vision_attn_probs = F.dropout(vision_attn_weights, p=self.dropout, training=self.training)
+         text_attn_probs = F.dropout(text_attn_weights, p=self.dropout, training=self.training)
+
+         vision_attn_output = torch.bmm(vision_attn_probs, text_value_states)
+         text_attn_output = torch.bmm(text_attn_probs, vision_value_states)
+
+         if vision_attn_output.size() != (batch_size * self.num_heads, tgt_len, self.head_dim):
+             raise ValueError(
+                 f"`vision_attn_output` should be of size {(batch_size, self.num_heads, tgt_len, self.head_dim)}, but is {vision_attn_output.size()}"
+             )
+
+         if text_attn_output.size() != (batch_size * self.num_heads, src_len, self.head_dim):
+             raise ValueError(
+                 f"`text_attn_output` should be of size {(batch_size, self.num_heads, src_len, self.head_dim)}, but is {text_attn_output.size()}"
+             )
+
+         vision_attn_output = vision_attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
+         vision_attn_output = vision_attn_output.transpose(1, 2)
+         vision_attn_output = vision_attn_output.reshape(batch_size, tgt_len, self.embed_dim)
+
+         text_attn_output = text_attn_output.view(batch_size, self.num_heads, src_len, self.head_dim)
+         text_attn_output = text_attn_output.transpose(1, 2)
+         text_attn_output = text_attn_output.reshape(batch_size, src_len, self.embed_dim)
+
+         vision_attn_output = self.out_vision_proj(vision_attn_output)
+         text_attn_output = self.out_text_proj(text_attn_output)
+
+         return (vision_attn_output, vision_attn_weights), (text_attn_output, text_attn_weights)
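
A note on the recurring "RBLN FIX" comments in the file above: the patched layers replace boolean mask handling (~mask / masked_fill) with plain float arithmetic such as (1.0 - mask) and mask * torch.finfo(torch.float16).min, presumably because float masks compile more predictably on the target hardware; the diff itself does not state the motivation. The standalone sketch below (not part of the package) illustrates that the two masking styles give effectively the same softmax output:

import torch

scores = torch.randn(1, 4)                               # attention logits for 4 keys
bool_mask = torch.tensor([[False, False, True, True]])   # True marks masked-out positions

# Boolean-style masking, as in the upstream transformers implementation.
masked_bool = scores.masked_fill(bool_mask, torch.finfo(torch.float16).min)

# Float-style masking, as in the patched RBLN layers: add a large negative bias where mask == 1.
float_mask = bool_mask.to(scores.dtype)
masked_float = scores + float_mask * torch.finfo(torch.float16).min

print(torch.allclose(masked_bool.softmax(-1), masked_float.softmax(-1), atol=1e-3))  # True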