optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py (added)
@@ -0,0 +1,599 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import math
+ from functools import wraps
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+ from transformers.models.grounding_dino.modeling_grounding_dino import (
+     GroundingDinoDecoder,
+     GroundingDinoEncoder,
+ )
+
+
+ if TYPE_CHECKING:
+     from .configuration_grounding_dino import RBLNGroundingDinoDecoderConfig, RBLNGroundingDinoEncoderConfig
+
+
+ def monkey_patch():
+     from transformers.models.grounding_dino.modeling_grounding_dino import (
+         GroundingDinoBiMultiHeadAttention,
+         GroundingDinoEncoderLayer,
+         GroundingDinoMultiscaleDeformableAttention,
+         MultiScaleDeformableAttention,
+     )
+
+     original_forward = GroundingDinoMultiscaleDeformableAttention.forward
+     original_bi_multihead_attention_forward = GroundingDinoBiMultiHeadAttention.forward
+     original_encoder_layer_forward = GroundingDinoEncoderLayer.forward
+     original_multiscale_deform_attn = MultiScaleDeformableAttention.forward
+
+     # Patch the methods with the custom implementations
+     GroundingDinoMultiscaleDeformableAttention.forward = _GroundingDinoMultiscaleDeformableAttention.forward
+     GroundingDinoBiMultiHeadAttention.forward = _GroundingDinoBiMultiHeadAttention.forward
+     GroundingDinoEncoderLayer.forward = _GroundingDinoEncoderLayer.forward
+     MultiScaleDeformableAttention.forward = _MultiScaleDeformableAttention.forward
+
+     return (
+         original_forward,
+         original_bi_multihead_attention_forward,
+         original_encoder_layer_forward,
+         original_multiscale_deform_attn,
+     )
+
+
+ def restore_monkey_patch(
+     original_forward,
+     original_bi_multihead_attention_forward,
+     original_encoder_layer_forward,
+     original_multiscale_deform_attn,
+ ):
+     from transformers.models.grounding_dino.modeling_grounding_dino import (
+         GroundingDinoBiMultiHeadAttention,
+         GroundingDinoEncoderLayer,
+         GroundingDinoMultiscaleDeformableAttention,
+         MultiScaleDeformableAttention,
+     )
+
+     # Restore the original methods
+     GroundingDinoMultiscaleDeformableAttention.forward = original_forward
+     GroundingDinoBiMultiHeadAttention.forward = original_bi_multihead_attention_forward
+     GroundingDinoEncoderLayer.forward = original_encoder_layer_forward
+     MultiScaleDeformableAttention.forward = original_multiscale_deform_attn
+
+
+ def monkey_patch_decorator(func):
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         # Apply monkey patch and capture original methods
+         original_functions = monkey_patch()
+         try:
+             # Call the original function
+             result = func(*args, **kwargs)
+         finally:
+             # Restore original methods
+             restore_monkey_patch(*original_functions)
+         return result
+
+     return wrapper
+
+
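Note: the decorator above follows a temporary patch/restore pattern: the RBLN-friendly forward methods are installed only for the duration of the wrapped call (for example during tracing for compilation) and the originals are restored even if that call fails. A minimal sketch of the same pattern, with hypothetical names used purely for illustration:

import torch
from functools import wraps

def swap_forward_during_call(module_cls, patched_forward):
    # Hypothetical helper illustrating the pattern used by monkey_patch_decorator.
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            original = module_cls.forward        # capture the original method
            module_cls.forward = patched_forward  # install the replacement
            try:
                return func(*args, **kwargs)      # e.g. a trace/compile call
            finally:
                module_cls.forward = original     # always restore, even on error
        return wrapper
    return decorator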
+ def get_sine_pos_embed(
+     pos_tensor: torch.Tensor, num_pos_feats: int = 128, temperature: int = 10000, exchange_xy: bool = True
+ ) -> Tensor:
+     scale = 2 * math.pi
+     dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
+     dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
+
+     scaled_pos = pos_tensor.unsqueeze(-1) * scale / dim_t
+     reshaped_pos = scaled_pos.view(*scaled_pos.shape[:-1], -1, 2)
+     sin_chunk, cos_chunk = torch.split(reshaped_pos, 1, dim=-1)
+     sin_embed = sin_chunk.squeeze(-1).sin()
+     cos_embed = cos_chunk.squeeze(-1).cos()
+
+     pos_embed = torch.stack((sin_embed, cos_embed), dim=-1).flatten(-2)
+
+     if exchange_xy and pos_tensor.shape[-1] >= 2:
+         swapped_embeds = torch.cat([pos_embed[..., 1:2, :], pos_embed[..., 0:1, :], pos_embed[..., 2:, :]], dim=-2)
+         pos_embed = swapped_embeds
+
+     position_embeddings = pos_embed.flatten(start_dim=-2)
+
+     return position_embeddings
+
+
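Note: get_sine_pos_embed maps each coordinate of a reference-point tensor to num_pos_feats interleaved sine/cosine features and flattens them, so a 4-coordinate box becomes a 4 * num_pos_feats vector. A quick shape check, assuming the function as defined above (sizes are illustrative):

import torch

pos = torch.rand(2, 900, 4)                       # (batch, num_queries, cx/cy/w/h)
emb = get_sine_pos_embed(pos, num_pos_feats=128)  # 128 features per coordinate
print(emb.shape)                                  # torch.Size([2, 900, 512])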
+ class _GroundingDinoEncoder(torch.nn.Module):
+     def __init__(self, model: "GroundingDinoEncoder", rbln_config: "RBLNGroundingDinoEncoderConfig"):
+         super().__init__()
+         self.layers = model.layers
+         self.config = model.config
+         self.rbln_config = rbln_config
+         self.spatial_shapes = self.rbln_config.spatial_shapes
+         self.spatial_shapes_list = self.rbln_config.spatial_shapes_list
+         self.text_position_embedding = model.layers[0].get_text_position_embeddings(
+             torch.zeros(1, model.config.max_text_len, model.config.d_model),
+             None,
+             torch.arange(model.config.max_text_len, dtype=torch.int32).unsqueeze(0),
+         )
+
+     @monkey_patch_decorator
+     def forward(
+         self,
+         vision_features: torch.Tensor,
+         vision_attention_mask: torch.Tensor,
+         vision_position_embedding: torch.Tensor,
+         text_features: Optional[torch.Tensor] = None,
+         text_attention_mask: Optional[torch.Tensor] = None,
+         text_self_attention_masks: Optional[torch.Tensor] = None,
+         reference_points: Optional[torch.Tensor] = None,
+     ):
+         output_attentions = self.rbln_config.output_attentions
+         output_hidden_states = self.rbln_config.output_hidden_states
+
+         encoder_vision_states = () if output_hidden_states else None
+         encoder_text_states = () if output_hidden_states else None
+         all_attns = () if output_attentions else None
+         all_attn_fused_text = () if output_attentions else None
+         all_attn_fused_vision = () if output_attentions else None
+         all_attn_enhanced_text = () if output_attentions else None
+         all_attn_deformable = () if output_attentions else None
+         for i, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_vision_states += (vision_features,)
+                 encoder_text_states += (text_features,)
+
+             (vision_features, text_features), attentions = encoder_layer(
+                 vision_features=vision_features,
+                 vision_position_embedding=vision_position_embedding,
+                 spatial_shapes=self.spatial_shapes,
+                 spatial_shapes_list=self.spatial_shapes_list,
+                 level_start_index=None,
+                 key_padding_mask=vision_attention_mask,
+                 reference_points=reference_points,
+                 text_features=text_features,
+                 text_attention_mask=text_attention_mask,
+                 text_position_embedding=self.text_position_embedding,
+                 text_self_attention_masks=text_self_attention_masks,
+             )
+             if output_attentions:
+                 all_attn_fused_vision += (attentions[0],)
+                 all_attn_fused_text += (attentions[1],)
+                 all_attn_enhanced_text += (attentions[2],)
+                 all_attn_deformable += (attentions[3],)
+
+         if output_hidden_states:
+             encoder_vision_states += (vision_features,)
+             encoder_text_states += (text_features,)
+
+         if output_attentions:
+             all_attns = (all_attn_fused_vision, all_attn_fused_text, all_attn_enhanced_text, all_attn_deformable)
+
+         enc_outputs = [vision_features, text_features, encoder_vision_states, encoder_text_states, all_attns]
+
+         return tuple(v for v in enc_outputs if v is not None)
+
+
+ class _GroundingDinoDecoder(torch.nn.Module):
+     def __init__(self, model: "GroundingDinoDecoder", rbln_config: "RBLNGroundingDinoDecoderConfig"):
+         super().__init__()
+         self.layers = model.layers
+         self.config = model.config
+         self.spatial_shapes = rbln_config.spatial_shapes
+         self.spatial_shapes_list = rbln_config.spatial_shapes_list
+         self.rbln_config = rbln_config
+         self.reference_points_head = model.reference_points_head
+         self.bbox_embed = model.bbox_embed
+         self.layer_norm = model.layer_norm
+
+     @monkey_patch_decorator
+     def forward(
+         self,
+         inputs_embeds,
+         vision_encoder_hidden_states,
+         vision_encoder_attention_mask=None,
+         text_encoder_hidden_states=None,
+         text_encoder_attention_mask=None,
+         reference_points=None,
+         valid_ratios=None,
+     ):
+         output_attentions = self.rbln_config.output_attentions
+         output_hidden_states = self.rbln_config.output_hidden_states
+
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         all_attns = () if output_attentions else None
+         all_cross_attns_vision = () if (output_attentions and vision_encoder_hidden_states is not None) else None
+         all_cross_attns_text = () if (output_attentions and text_encoder_hidden_states is not None) else None
+         intermediate = ()
+         intermediate_reference_points = ()
+
+         if text_encoder_attention_mask is not None:
+             text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :]
+             text_encoder_attention_mask = text_encoder_attention_mask.repeat(
+                 1, self.config.decoder_attention_heads, self.config.num_queries, 1
+             )
+             text_encoder_attention_mask = text_encoder_attention_mask
+             text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(torch.float16).min
+
+         for idx, decoder_layer in enumerate(self.layers):
+             num_coordinates = reference_points.shape[-1]
+             if num_coordinates == 4:
+                 reference_points_input = (
+                     reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+                 )
+             elif num_coordinates == 2:
+                 reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
+             else:
+                 raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+             _query_pos = get_sine_pos_embed(reference_points_input[:, :, 0, :], num_pos_feats=self.config.d_model // 2)
+             query_pos = self.reference_points_head(_query_pos)
+
+             # In the original implementation, layer norm is applied before outputting intermediate hidden states,
+             # but not between layers, so each layer takes as input the previous layer's output
+             # without layer norm
+             if output_hidden_states:
+                 all_hidden_states += (self.layer_norm(hidden_states),)
+
+             layer_outputs = decoder_layer(
+                 hidden_states=hidden_states,
+                 position_embeddings=query_pos,
+                 reference_points=reference_points_input,
+                 spatial_shapes=self.spatial_shapes,
+                 spatial_shapes_list=self.spatial_shapes_list,
+                 level_start_index=None,
+                 vision_encoder_hidden_states=vision_encoder_hidden_states,
+                 vision_encoder_attention_mask=vision_encoder_attention_mask,
+                 text_encoder_hidden_states=text_encoder_hidden_states,
+                 text_encoder_attention_mask=text_encoder_attention_mask,
+                 self_attn_mask=None,
+                 output_attentions=output_attentions,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             # hack implementation for iterative bounding box refinement
+             if self.bbox_embed is not None:
+                 tmp = self.bbox_embed[idx](hidden_states)
+                 num_coordinates = reference_points.shape[-1]
+                 if num_coordinates == 4:
+                     new_reference_points = tmp + torch.special.logit(reference_points, eps=1e-5)
+                     new_reference_points = new_reference_points.sigmoid()
+                 elif num_coordinates == 2:
+                     new_reference_points = tmp
+                     new_reference_points[..., :2] = tmp[..., :2] + torch.special.logit(reference_points, eps=1e-5)
+                     new_reference_points = new_reference_points.sigmoid()
+                 else:
+                     raise ValueError(
+                         f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}"
+                     )
+                 reference_points = new_reference_points.detach()
+
+             intermediate += (self.layer_norm(hidden_states),)
+             intermediate_reference_points += (reference_points,)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+                 if text_encoder_hidden_states is not None:
+                     all_cross_attns_text += (layer_outputs[2],)
+
+                 if vision_encoder_hidden_states is not None:
+                     all_cross_attns_vision += (layer_outputs[3],)
+
+         # Keep batch_size as first dimension
+         intermediate = torch.stack(intermediate, dim=1)
+         intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
+         hidden_states = self.layer_norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         if output_attentions:
+             all_attns += (all_self_attns, all_cross_attns_text, all_cross_attns_vision)
+
+         return tuple(
+             v
+             for v in [
+                 hidden_states,
+                 intermediate,
+                 intermediate_reference_points,
+                 all_hidden_states,
+                 all_attns,
+             ]
+             if v is not None
+         )
+
+
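Note: the decoder refines reference boxes iteratively by adding each layer's predicted delta in logit space and mapping the result back through sigmoid, which keeps the coordinates in [0, 1]. A minimal illustration of that update step (values are made up):

import torch

reference = torch.tensor([[0.40, 0.55, 0.20, 0.30]])  # normalized cx, cy, w, h
delta = torch.tensor([[0.10, -0.05, 0.02, 0.00]])     # e.g. what bbox_embed[idx] would predict
refined = (delta + torch.special.logit(reference, eps=1e-5)).sigmoid()
print(refined)  # still a valid normalized box in [0, 1]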
+ class _GroundingDinoEncoderLayer(torch.nn.Module):
+     def forward(
+         self,
+         vision_features: Tensor,
+         vision_position_embedding: Tensor,
+         spatial_shapes: Tensor,
+         spatial_shapes_list: List[Tuple[int, int]],
+         level_start_index: Tensor,
+         key_padding_mask: Tensor,
+         reference_points: Tensor,
+         text_features: Optional[Tensor] = None,
+         text_attention_mask: Optional[Tensor] = None,
+         text_position_embedding: Optional[Tensor] = None,
+         text_self_attention_masks: Optional[Tensor] = None,
+         text_position_ids: Optional[Tensor] = None,
+     ):
+         text_position_embedding = self.get_text_position_embeddings(
+             text_features, text_position_embedding, text_position_ids
+         )
+
+         (vision_features, vision_fused_attn), (text_features, text_fused_attn) = self.fusion_layer(
+             vision_features=vision_features,
+             text_features=text_features,
+             attention_mask_vision=key_padding_mask,
+             attention_mask_text=text_attention_mask,
+         )
+
+         (text_features, text_enhanced_attn) = self.text_enhancer_layer(
+             hidden_states=text_features,
+             attention_masks=(1.0 - text_self_attention_masks),  # RBLN FIX: change from ~ to 1.0 -
+             position_embeddings=(text_position_embedding if text_position_embedding is not None else None),
+         )
+
+         (vision_features, vision_deformable_attn) = self.deformable_layer(
+             hidden_states=vision_features,
+             attention_mask=(1.0 - key_padding_mask),  # RBLN FIX: change from ~ to 1.0 -
+             position_embeddings=vision_position_embedding,
+             reference_points=reference_points,
+             spatial_shapes=spatial_shapes,
+             spatial_shapes_list=spatial_shapes_list,
+             level_start_index=level_start_index,
+         )
+
+         return (
+             (vision_features, text_features),
+             (vision_fused_attn, text_fused_attn, text_enhanced_attn, vision_deformable_attn),
+         )
+
+
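Note: the two "RBLN FIX" comments above replace boolean inversion (~mask) with the arithmetic complement (1.0 - mask). For a mask containing only 0s and 1s the two are equivalent, and the float form presumably keeps the traced graph in plain floating-point arithmetic. A small equivalence check:

import torch

mask_bool = torch.tensor([[True, False, True]])
mask_float = mask_bool.to(torch.float32)  # 1.0 where True, 0.0 where False
assert torch.equal((~mask_bool).to(torch.float32), 1.0 - mask_float)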
+ class _GroundingDinoMultiscaleDeformableAttention(torch.nn.Module):
+     """
+     Multiscale deformable attention as proposed in Deformable DETR.
+     """
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         position_embeddings: Optional[torch.Tensor] = None,
+         reference_points=None,
+         spatial_shapes=None,
+         spatial_shapes_list=None,
+         level_start_index=None,
+         output_attentions: bool = False,
+     ):
+         # add position embeddings to the hidden states before projecting to queries and keys
+         if position_embeddings is not None:
+             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
+
+         batch_size, num_queries, _ = hidden_states.shape
+         batch_size, sequence_length, _ = encoder_hidden_states.shape
+         # Ignore copy
+         if torch.compiler.is_exporting():
+             torch._check(
+                 (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item() == sequence_length,
+                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
+             )
+         else:
+             if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+                 raise ValueError(
+                     "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
+                 )
+
+         value = self.value_proj(encoder_hidden_states)
+         if attention_mask is not None:
+             # RBLN FIX: bool tensor to float tensor
+             value = attention_mask * value
+
+         value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
+         sampling_offsets = self.sampling_offsets(hidden_states).view(
+             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
+         )
+         attention_weights = self.attention_weights(hidden_states).view(
+             batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
+         )
+         attention_weights = F.softmax(attention_weights, -1).view(
+             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
+         )
+         # batch_size, num_queries, n_heads, n_levels, n_points, 2
+         num_coordinates = reference_points.shape[-1]
+         if num_coordinates == 2:
+             offset_normalizer = 0.5 * torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+             sampling_grids = (
+                 2 * reference_points[:, :, None, :, None, :]
+                 - 1
+                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+             )
+         elif num_coordinates == 4:
+             ref_points_xy, ref_points_wh = torch.split(reference_points, 2, dim=-1)
+             ref_points_xy = ref_points_xy[:, :, None, :, None, :]
+             ref_points_wh = ref_points_wh[:, :, None, :, None, :]
+             ref_points_grids = 2 * ref_points_xy - 1
+             offset_grids = sampling_offsets / self.n_points * ref_points_wh
+             sampling_grids = ref_points_grids + offset_grids
+
+         else:
+             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
+
+         output = self.attn(
+             value,
+             spatial_shapes,
+             spatial_shapes_list,
+             level_start_index,
+             sampling_grids,
+             attention_weights,
+             self.im2col_step,
+         )
+
+         output = self.output_proj(output)
+
+         return output, attention_weights
+
+
+ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
+     def forward(
+         self,
+         vision_features: torch.FloatTensor,
+         text_features: torch.FloatTensor,
+         vision_attention_mask: Optional[torch.BoolTensor] = None,
+         text_attention_mask: Optional[torch.BoolTensor] = None,
+     ) -> Tuple[Tuple[torch.FloatTensor, torch.FloatTensor], Tuple[torch.FloatTensor, torch.FloatTensor]]:
+         batch_size, tgt_len, _ = vision_features.size()
+
+         vision_query_states = self.vision_proj(vision_features) * self.scale
+         vision_query_states = self._reshape(vision_query_states, tgt_len, batch_size)
+
+         text_key_states = self.text_proj(text_features)
+         text_key_states = self._reshape(text_key_states, -1, batch_size)
+
+         vision_value_states = self.values_vision_proj(vision_features)
+         vision_value_states = self._reshape(vision_value_states, -1, batch_size)
+
+         text_value_states = self.values_text_proj(text_features)
+         text_value_states = self._reshape(text_value_states, -1, batch_size)
+
+         proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
+
+         vision_query_states = vision_query_states.view(*proj_shape)
+         text_key_states = text_key_states.view(*proj_shape)
+         vision_value_states = vision_value_states.view(*proj_shape)
+         text_value_states = text_value_states.view(*proj_shape)
+
+         src_len = text_key_states.size(1)
+         attn_weights = torch.bmm(vision_query_states, text_key_states.transpose(1, 2))  # bs*nhead, nimg, ntxt
+
+         if attn_weights.size() != (batch_size * self.num_heads, tgt_len, src_len):
+             raise ValueError(
+                 f"Attention weights should be of size {(batch_size * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+             )
+
+         # RBLN FIX: max_values from scalar to vector
+         attn_weights = attn_weights - torch.max(attn_weights).reshape(1).repeat(src_len)
+         # Do not increase -50000/50000, data type half has quite limited range
+         attn_weights = torch.clamp(attn_weights, min=-50000, max=50000)
+
+         # RBLN FIX: max_values from scalar to vector
+         text_attn_weights = attn_weights - torch.max(attn_weights, dim=1, keepdim=True)[0].repeat(1, tgt_len, 1)
+
+         # Do not increase -50000/50000, data type half has quite limited range
+         text_attn_weights = torch.clamp(text_attn_weights, min=-50000, max=50000)
+
+         text_attn_weights = text_attn_weights.transpose(1, 2)
+
+         # mask vision for language
+         if vision_attention_mask is not None:
+             # RBLN FIX: bool tensor to float tensor
+             mask = vision_attention_mask * torch.finfo(torch.float16).min
+             text_attn_weights = text_attn_weights.transpose(1, 2) + mask
+             text_attn_weights = text_attn_weights.transpose(1, 2)
+
+         text_attn_weights = text_attn_weights.softmax(dim=-1)
+
+         # mask language for vision
+         if text_attention_mask is not None:
+             text_attention_mask = text_attention_mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+             # RBLN FIX: bool tensor to float tensor
+             mask = text_attention_mask * torch.finfo(torch.float16).min
+             attn_weights = attn_weights + mask
+
+         vision_attn_weights = attn_weights.softmax(dim=-1)
+
+         vision_attn_probs = F.dropout(vision_attn_weights, p=self.dropout, training=self.training)
+         text_attn_probs = F.dropout(text_attn_weights, p=self.dropout, training=self.training)
+
+         vision_attn_output = torch.bmm(vision_attn_probs, text_value_states)
+         text_attn_output = torch.bmm(text_attn_probs, vision_value_states)
+
+         if vision_attn_output.size() != (batch_size * self.num_heads, tgt_len, self.head_dim):
+             raise ValueError(
+                 f"`vision_attn_output` should be of size {(batch_size, self.num_heads, tgt_len, self.head_dim)}, but is {vision_attn_output.size()}"
+             )
+
+         if text_attn_output.size() != (batch_size * self.num_heads, src_len, self.head_dim):
+             raise ValueError(
+                 f"`text_attn_output` should be of size {(batch_size, self.num_heads, src_len, self.head_dim)}, but is {text_attn_output.size()}"
+             )
+
+         vision_attn_output = vision_attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
+         vision_attn_output = vision_attn_output.transpose(1, 2)
+         vision_attn_output = vision_attn_output.reshape(batch_size, tgt_len, self.embed_dim)
+
+         text_attn_output = text_attn_output.view(batch_size, self.num_heads, src_len, self.head_dim)
+         text_attn_output = text_attn_output.transpose(1, 2)
+         text_attn_output = text_attn_output.reshape(batch_size, src_len, self.embed_dim)
+
+         vision_attn_output = self.out_vision_proj(vision_attn_output)
+         text_attn_output = self.out_text_proj(text_attn_output)
+
+         return (vision_attn_output, vision_attn_weights), (text_attn_output, text_attn_weights)
+
+
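Note: the "max_values from scalar to vector" fixes above are the usual softmax stabilization: subtracting a broadcast maximum before softmax leaves the result unchanged but keeps the intermediate exponentials inside float16 range (the clamp to ±50000 serves the same goal). An illustrative check:

import torch

logits = torch.tensor([[12.0, 11.0, 10.0]])
shifted = logits - logits.max(dim=-1, keepdim=True).values
assert torch.allclose(torch.softmax(logits, dim=-1), torch.softmax(shifted, dim=-1))
# exp(12) ~ 1.6e5 already overflows float16 (max ~65504); exp(0) after the shift does not.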
+ class _MultiScaleDeformableAttention(torch.nn.Module):
+     def forward(
+         self,
+         value: Tensor,
+         value_spatial_shapes: Tensor,
+         value_spatial_shapes_list: List[Tuple],
+         level_start_index: Tensor,
+         sampling_grids: Tensor,
+         attention_weights: Tensor,
+         im2col_step: int,
+     ):
+         batch_size, _, num_heads, hidden_dim = value.shape
+         _, num_queries, num_heads, num_levels, num_points, _ = sampling_grids.shape
+         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
+         sampling_value_list = []
+         sampling_grids_list = [t.squeeze(3) for t in torch.split(sampling_grids, 1, dim=3)]
+         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
+             value_l_ = (
+                 value_list[level_id].permute(0, 2, 3, 1).reshape(batch_size * num_heads, hidden_dim, height, width)
+             )
+             sampling_grid_l_ = sampling_grids_list[level_id].transpose(1, 2).flatten(0, 1)
+             sampling_value_l_ = torch.nn.functional.grid_sample(
+                 value_l_,
+                 sampling_grid_l_,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=False,
+             )
+             sampling_value_list.append(sampling_value_l_)
+
+         sampling_values = torch.cat(sampling_value_list, dim=-1)
+         attention_weights_prep = attention_weights.transpose(1, 2)
+         values_permuted = sampling_values.permute(0, 2, 3, 1)
+
+         weights_for_matmul = attention_weights_prep.reshape(
+             batch_size * num_heads, num_queries, 1, num_levels * num_points
+         )
+         output_before_permute = torch.matmul(weights_for_matmul, values_permuted)
+         output_before_view = output_before_permute.squeeze(2).permute(0, 2, 1)
+         output = output_before_view.reshape(batch_size, num_heads * hidden_dim, num_queries)
+
+         return output.transpose(1, 2).contiguous()
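
Note: the patched attention above samples each feature level with grid_sample and then combines the samples with a weighted matmul, presumably to stay within ops the RBLN compiler can lower. A per-level shape walk-through with illustrative sizes only:

import torch
import torch.nn.functional as F

bh, dim, H, W, queries, points = 8, 32, 25, 25, 900, 4   # batch*heads, head dim, feature map, ...
value_l = torch.rand(bh, dim, H, W)                       # one feature level
grid_l = torch.rand(bh, queries, points, 2) * 2 - 1       # sampling locations in [-1, 1]
sampled = F.grid_sample(value_l, grid_l, mode="bilinear", padding_mode="zeros", align_corners=False)
print(sampled.shape)  # torch.Size([8, 32, 900, 4]) -> concatenated over levels, then weighted and summed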