optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
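The largest single addition in this release is Grounding DINO support (files 90–93 above); the hunk below is the new optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py (+1032 lines). As a minimal, hypothetical sketch of how the new class would be compiled with the usual optimum-rbln export flow (the checkpoint id, output directory, and the assumption that the class is re-exported at the package root are mine, not taken from this diff):

# Hypothetical compile-and-save sketch for the new Grounding DINO support.
# Assumes the standard optimum-rbln `from_pretrained(..., export=True)` flow and
# that RBLNGroundingDinoForObjectDetection is re-exported from optimum.rbln.
from optimum.rbln import RBLNGroundingDinoForObjectDetection

model = RBLNGroundingDinoForObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny",  # example upstream checkpoint (assumption)
    export=True,                          # trace, compile, and package for RBLN NPUs
)
model.save_pretrained("grounding-dino-tiny-rbln")  # reload later without export=True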
@@ -0,0 +1,1032 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+ from transformers.modeling_utils import no_init_weights
21
+ from transformers.models.grounding_dino.modeling_grounding_dino import (
22
+ GroundingDinoContrastiveEmbedding,
23
+ GroundingDinoConvEncoder,
24
+ GroundingDinoDecoderOutput,
25
+ GroundingDinoEncoderOutput,
26
+ GroundingDinoMLPPredictionHead,
27
+ GroundingDinoModel,
28
+ GroundingDinoModelOutput,
29
+ GroundingDinoObjectDetectionOutput,
30
+ build_position_encoding,
31
+ generate_masks_with_special_tokens_and_transfer_map,
32
+ )
33
+ from transformers.pytorch_utils import meshgrid
34
+
35
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
36
+ from ....modeling import RBLNModel
37
+ from ....utils.runtime_utils import RBLNPytorchRuntime
38
+ from .configuration_grounding_dino import (
39
+ RBLNGroundingDinoDecoderConfig,
40
+ RBLNGroundingDinoEncoderConfig,
41
+ RBLNGroundingDinoForObjectDetectionConfig,
42
+ )
43
+ from .grounding_dino_architecture import (
44
+ _GroundingDinoDecoder,
45
+ _GroundingDinoEncoder,
46
+ )
47
+
48
+
49
+ if TYPE_CHECKING:
50
+ from transformers import (
51
+ AutoFeatureExtractor,
52
+ AutoProcessor,
53
+ AutoTokenizer,
54
+ PreTrainedModel,
55
+ )
56
+
57
+
58
+ class RBLNGroundingDinoForObjectDetection(RBLNModel):
59
+ _rbln_submodules = [
60
+ {"name": "text_backbone"},
61
+ {"name": "backbone"},
62
+ {"name": "encoder"},
63
+ {"name": "decoder"},
64
+ ]
65
+ """
66
+ RBLN optimized Grounding DINO model for object detection.
67
+ This class provides hardware-accelerated inference for Grounding DINO models
68
+ on RBLN devices, supporting multimodal object detection tasks that combine
69
+ vision and language understanding.
70
+
71
+ Grounding DINO is a transformer-based architecture consisting of:
72
+ - A backbone for feature extraction from images
73
+ - An encoder-decoder transformer for processing visual and textual features
74
+ - Object detection heads for predicting bounding boxes and class labels
75
+ """
76
+
77
+ def __post_init__(self, **kwargs):
78
+ self._setup_cpu_instances()
79
+ self.text_projection = RBLNPytorchRuntime(self.model[0])
80
+ self.text_backbone = self.rbln_submodules[0]
81
+ self.backbone = self.rbln_submodules[1]
82
+ self.encoder = self.rbln_submodules[2]
83
+ self.decoder = self.rbln_submodules[3]
84
+
85
+ def _setup_cpu_instances(self):
86
+ state_dict = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
87
+ with no_init_weights():
88
+ config = self.config
89
+ _class_embed = GroundingDinoContrastiveEmbedding(config)
90
+ if config.decoder_bbox_embed_share: # True
91
+ _bbox_embed = GroundingDinoMLPPredictionHead(
92
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
93
+ )
94
+ self.bbox_embed = nn.ModuleList([_bbox_embed for _ in range(config.decoder_layers)])
95
+ else:
96
+ for _ in range(config.decoder_layers):
97
+ _bbox_embed = GroundingDinoMLPPredictionHead(
98
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
99
+ )
100
+ self.bbox_embed = nn.ModuleList([_bbox_embed for _ in range(config.decoder_layers)])
101
+ self.class_embed = nn.ModuleList([_class_embed for _ in range(config.decoder_layers)])
102
+
103
+ backbone = GroundingDinoConvEncoder(config)
104
+ self.backbone_position_embedding = build_position_encoding(self.config)
105
+ # Create input projection layers
106
+ if config.num_feature_levels > 1:
107
+ num_backbone_outs = len(backbone.intermediate_channel_sizes)
108
+ input_proj_list = []
109
+ for i in range(num_backbone_outs):
110
+ in_channels = backbone.intermediate_channel_sizes[i]
111
+ input_proj_list.append(
112
+ nn.Sequential(
113
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1),
114
+ nn.GroupNorm(32, config.d_model),
115
+ )
116
+ )
117
+ for _ in range(config.num_feature_levels - num_backbone_outs):
118
+ input_proj_list.append(
119
+ nn.Sequential(
120
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
121
+ nn.GroupNorm(32, config.d_model),
122
+ )
123
+ )
124
+ in_channels = config.d_model
125
+ self.input_proj_vision = nn.ModuleList(input_proj_list)
126
+ else:
127
+ self.input_proj_vision = nn.ModuleList(
128
+ [
129
+ nn.Sequential(
130
+ nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
131
+ nn.GroupNorm(32, config.d_model),
132
+ )
133
+ ]
134
+ )
135
+
136
+ if config.embedding_init_target or not config.two_stage:
137
+ self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
138
+
139
+ self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model))
140
+
141
+ if config.two_stage:
142
+ self.enc_output = nn.Linear(config.d_model, config.d_model)
143
+ self.enc_output_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
144
+ if (
145
+ config.two_stage_bbox_embed_share
146
+ and config.decoder_bbox_embed_share
147
+ and self.decoder.bbox_embed is not None
148
+ ):
149
+ self.encoder_output_bbox_embed = self.decoder.bbox_embed
150
+ else:
151
+ self.encoder_output_bbox_embed = GroundingDinoMLPPredictionHead(
152
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
153
+ )
154
+
155
+ self.encoder_output_class_embed = GroundingDinoContrastiveEmbedding(config)
156
+ else:
157
+ self.reference_points = nn.Embedding(config.num_queries, 4)
158
+
159
+ self.bbox_embed.load_state_dict(state_dict["bbox_embed"])
160
+ self.class_embed.load_state_dict(state_dict["class_embed"])
161
+ self.input_proj_vision.load_state_dict(state_dict["input_proj_vision"])
162
+ with torch.no_grad():
163
+ self.level_embed.copy_(state_dict["level_embed"])
164
+ if self.config.two_stage:
165
+ self.enc_output.load_state_dict(state_dict["enc_output"])
166
+ self.enc_output_norm.load_state_dict(state_dict["enc_output_norm"])
167
+ self.encoder_output_class_embed.load_state_dict(state_dict["encoder_output_class_embed"])
168
+ self.encoder_output_bbox_embed.load_state_dict(state_dict["encoder_output_bbox_embed"])
169
+ else:
170
+ self.reference_points.load_state_dict(state_dict["reference_points"])
171
+ if self.config.embedding_init_target or not self.config.two_stage:
172
+ self.query_position_embeddings.load_state_dict(state_dict["query_position_embeddings"])
173
+
174
+ if self.config.position_embedding_type == "learned":
175
+ self.backbone_position_embedding.load_state_dict(state_dict["backbone_position_embedding"])
176
+
177
+ @classmethod
178
+ def save_torch_artifacts(
179
+ cls,
180
+ model: "PreTrainedModel",
181
+ save_dir_path: Path,
182
+ subfolder: str,
183
+ rbln_config: RBLNGroundingDinoForObjectDetectionConfig,
184
+ ):
185
+ # Modules that must run on the host CPU (rather than the RBLN device) have their
186
+ # tensors and weights saved here, to be restored later in _setup_cpu_instances.
187
+ save_dict = {}
188
+ save_dict["input_proj_vision"] = model.model.input_proj_vision.state_dict()
189
+ save_dict["level_embed"] = model.model.level_embed
190
+ if model.config.two_stage:
191
+ save_dict["enc_output"] = model.model.enc_output.state_dict()
192
+ save_dict["enc_output_norm"] = model.model.enc_output_norm.state_dict()
193
+ save_dict["encoder_output_class_embed"] = model.model.encoder_output_class_embed.state_dict()
194
+ save_dict["encoder_output_bbox_embed"] = model.model.encoder_output_bbox_embed.state_dict()
195
+ else:
196
+ save_dict["reference_points"] = model.model.reference_points.state_dict()
197
+ if model.config.embedding_init_target or not model.config.two_stage:
198
+ save_dict["query_position_embeddings"] = model.model.query_position_embeddings.state_dict()
199
+
200
+ if model.config.position_embedding_type == "learned":
201
+ save_dict["backbone_position_embedding"] = model.model.backbone.position_embedding.state_dict()
202
+
203
+ save_dict["class_embed"] = model.class_embed.state_dict()
204
+ save_dict["bbox_embed"] = model.bbox_embed.state_dict()
205
+
206
+ torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
207
+
208
+ @classmethod
209
+ def get_pytorch_model(cls, *args, **kwargs):
210
+ model = super().get_pytorch_model(*args, **kwargs)
211
+ model.encoder = model.model.encoder
212
+ model.decoder = model.model.decoder
213
+ model.text_backbone = model.model.text_backbone
214
+ model.encoder.config = model.config
215
+ model.decoder.config = model.config
216
+ model.backbone = model.model.backbone.conv_encoder.model
217
+ return model
218
+
219
+ @classmethod
220
+ def wrap_model_if_needed(
221
+ cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
222
+ ) -> torch.nn.Module:
223
+ return model.model.text_projection
224
+
225
+ @classmethod
226
+ def _update_rbln_config(
227
+ cls,
228
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
229
+ model: Optional["PreTrainedModel"] = None,
230
+ model_config: RBLNGroundingDinoForObjectDetectionConfig = None,
231
+ rbln_config: Optional[RBLNGroundingDinoForObjectDetectionConfig] = None,
232
+ ) -> RBLNGroundingDinoForObjectDetectionConfig:
233
+ input_info = [
234
+ (
235
+ "test_features",
236
+ [rbln_config.batch_size, model_config.max_text_len, model_config.text_config.hidden_size],
237
+ "float32",
238
+ ),
239
+ ]
240
+
241
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
242
+ return rbln_config
243
+
244
+ def generate_encoder_output_proposals(self, *args, **kwargs):
245
+ return GroundingDinoModel.generate_encoder_output_proposals(self, *args, **kwargs)
246
+
247
+ def get_valid_ratio(self, *args, **kwargs):
248
+ return GroundingDinoModel.get_valid_ratio(self, *args, **kwargs)
249
+
250
+ def _model_forward(
251
+ self,
252
+ pixel_values: Tensor,
253
+ input_ids: Tensor,
254
+ token_type_ids: Optional[Tensor] = None,
255
+ attention_mask: Optional[Tensor] = None,
256
+ pixel_mask: Optional[Tensor] = None,
257
+ encoder_outputs=None,
258
+ output_attentions=None,
259
+ output_hidden_states=None,
260
+ return_dict=None,
261
+ _init_reference_points=None,
262
+ ):
263
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
264
+ output_hidden_states = (
265
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
266
+ )
267
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
268
+
269
+ text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids)
270
+
271
+ text_token_mask = attention_mask.bool() # just to avoid renaming everywhere
272
+
273
+ max_text_len = self.config.max_text_len
274
+ if text_self_attention_masks.shape[1] > max_text_len:
275
+ text_self_attention_masks = text_self_attention_masks[:, :max_text_len, :max_text_len]
276
+ position_ids = position_ids[:, :max_text_len]
277
+ input_ids = input_ids[:, :max_text_len]
278
+ token_type_ids = token_type_ids[:, :max_text_len]
279
+ text_token_mask = text_token_mask[:, :max_text_len]
280
+
281
+ # Extract text features from text backbone
282
+ text_outputs = self.text_backbone(
283
+ input_ids, text_self_attention_masks.to(torch.long), token_type_ids, position_ids, return_dict=return_dict
284
+ )
285
+ text_features = text_outputs.last_hidden_state if return_dict else text_outputs[0]
286
+ text_features = self.text_projection(text_features)
287
+
288
+ batch_size, num_channels, height, width = pixel_values.shape
289
+ device = pixel_values.device
290
+
291
+ if pixel_mask is None:
292
+ pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
293
+
294
+ # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
295
+ # First, send pixel_values + pixel_mask through Backbone to obtain the features
296
+ # which is a list of tuples
297
+ features = self.backbone(pixel_values)[0]
298
+ vision_features = []
299
+ for feature_map in features:
300
+ # downsample pixel_mask to match shape of corresponding feature_map
301
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
302
+ vision_features.append((feature_map, mask))
303
+
304
+ position_embeddings_list = []
305
+ for feature_map, mask in vision_features:
306
+ # position encoding
307
+ position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))
309
+
310
+ # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
311
+ feature_maps = []
312
+ masks = []
313
+ for level, (source, mask) in enumerate(vision_features):
314
+ feature_maps.append(self.input_proj_vision[level](source))
315
+ masks.append(mask)
316
+
317
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
318
+ if self.config.num_feature_levels > len(feature_maps):
319
+ _len_sources = len(feature_maps)
320
+ for level in range(_len_sources, self.config.num_feature_levels):
321
+ if level == _len_sources:
322
+ source = self.input_proj_vision[level](vision_features[-1][0])
323
+ else:
324
+ source = self.input_proj_vision[level](feature_maps[-1])
325
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
326
+ pos_l = self.backbone_position_embedding(source, mask).to(source.dtype)
327
+ feature_maps.append(source)
328
+ masks.append(mask)
329
+ position_embeddings_list.append(pos_l)
330
+
331
+ # Create queries
332
+ query_embeds = None
333
+ if self.config.embedding_init_target or self.config.two_stage:
334
+ query_embeds = self.query_position_embeddings.weight
335
+
336
+ # Prepare encoder inputs (by flattening)
337
+ source_flatten = []
338
+ mask_flatten = []
339
+ lvl_pos_embed_flatten = []
340
+ spatial_shapes_list = []
341
+ for level, (source, mask, pos_embed) in enumerate(zip(feature_maps, masks, position_embeddings_list)):
342
+ batch_size, num_channels, height, width = source.shape
343
+ spatial_shape = (height, width)
344
+ spatial_shapes_list.append(spatial_shape)
345
+ source = source.flatten(2).transpose(1, 2)
346
+ mask = mask.flatten(1)
347
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
348
+ lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
349
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
350
+ source_flatten.append(source)
351
+ mask_flatten.append(mask)
352
+ source_flatten = torch.cat(source_flatten, 1)
353
+ mask_flatten = torch.cat(mask_flatten, 1)
354
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
355
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
356
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
357
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
358
+ valid_ratios = valid_ratios.float()
359
+
360
+ # Fourth, send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through encoder
361
+ # Also provide spatial_shapes, level_start_index and valid_ratios
362
+ if encoder_outputs is None:
363
+ encoder_outputs = self.encoder(
364
+ vision_features=source_flatten,
365
+ vision_attention_mask=~mask_flatten,
366
+ vision_position_embedding=lvl_pos_embed_flatten,
367
+ spatial_shapes=spatial_shapes,
368
+ spatial_shapes_list=spatial_shapes_list,
369
+ level_start_index=level_start_index,
370
+ valid_ratios=valid_ratios,
371
+ text_features=text_features,
372
+ text_attention_mask=~text_token_mask,
373
+ text_position_embedding=None,
374
+ text_self_attention_masks=~text_self_attention_masks,
375
+ text_position_ids=position_ids,
376
+ output_attentions=output_attentions,
377
+ output_hidden_states=output_hidden_states,
378
+ return_dict=True,
379
+ )
380
+
381
+ # Fifth, prepare decoder inputs
382
+ topk_proposals = None
383
+ enc_outputs_class = None
384
+ enc_outputs_coord_logits = None
385
+ encoder_logits = None
386
+ encoder_pred_boxes = None
387
+ if self.config.two_stage:
388
+ object_query_embedding, output_proposals = self.generate_encoder_output_proposals(
389
+ encoder_outputs[0], ~mask_flatten, spatial_shapes
390
+ )
391
+
392
+ # hack implementation as in two-stage Deformable DETR
393
+ # apply a detection head to each pixel (A.4 in paper)
394
+ # linear projection for bounding box binary classification (i.e. foreground and background)
395
+ enc_outputs_class = self.encoder_output_class_embed(
396
+ object_query_embedding, encoder_outputs[1], text_token_mask
397
+ )
398
+ # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch)
399
+ delta_bbox = self.encoder_output_bbox_embed(object_query_embedding)
400
+ enc_outputs_coord_logits = delta_bbox + output_proposals
401
+
402
+ # only keep top scoring `config.num_queries` proposals
403
+ topk = self.config.num_queries
404
+ topk_logits = enc_outputs_class.max(-1)[0]
405
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]
406
+ topk_coords_logits = torch.gather(
407
+ enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
408
+ )
409
+
410
+ topk_coords_logits = topk_coords_logits.detach()
411
+ reference_points = (
412
+ topk_coords_logits.sigmoid() if _init_reference_points is None else _init_reference_points
413
+ )
414
+ init_reference_points = reference_points
415
+ if query_embeds is not None:
416
+ target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
417
+ else:
418
+ target = torch.gather(
419
+ object_query_embedding, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
420
+ ).detach()
421
+
422
+ # Set intermediate topk proposals (coords and class) for loss computation
423
+ encoder_pred_boxes = reference_points
424
+ encoder_logits = self.encoder_output_class_embed(target, text_features, text_token_mask)
425
+ else:
426
+ target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
427
+ reference_points = self.reference_points.weight.unsqueeze(0).repeat(batch_size, 1, 1).sigmoid()
428
+ init_reference_points = reference_points
429
+
430
+ decoder_outputs = self.decoder(
431
+ inputs_embeds=target,
432
+ vision_encoder_hidden_states=encoder_outputs[0],
433
+ vision_encoder_attention_mask=mask_flatten,
434
+ text_encoder_hidden_states=encoder_outputs[1],
435
+ text_encoder_attention_mask=~text_token_mask,
436
+ reference_points=reference_points,
437
+ spatial_shapes=spatial_shapes,
438
+ spatial_shapes_list=spatial_shapes_list,
439
+ level_start_index=level_start_index,
440
+ valid_ratios=valid_ratios,
441
+ self_attn_mask=None,
442
+ output_attentions=output_attentions,
443
+ output_hidden_states=output_hidden_states,
444
+ return_dict=return_dict,
445
+ )
446
+
447
+ if not return_dict:
448
+ enc_outputs = tuple(
449
+ value
450
+ for value in [
451
+ enc_outputs_class,
452
+ enc_outputs_coord_logits,
453
+ encoder_logits,
454
+ encoder_pred_boxes,
455
+ ]
456
+ if value is not None
457
+ )
458
+ tuple_outputs = (
459
+ (decoder_outputs[0], init_reference_points) + decoder_outputs[1:] + encoder_outputs + enc_outputs
460
+ )
461
+
462
+ return tuple_outputs
463
+
464
+ return GroundingDinoModelOutput(
465
+ last_hidden_state=decoder_outputs.last_hidden_state,
466
+ init_reference_points=init_reference_points,
467
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
468
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
469
+ decoder_hidden_states=decoder_outputs.hidden_states,
470
+ decoder_attentions=decoder_outputs.attentions,
471
+ encoder_last_hidden_state_vision=encoder_outputs.last_hidden_state_vision,
472
+ encoder_last_hidden_state_text=encoder_outputs.last_hidden_state_text,
473
+ encoder_vision_hidden_states=encoder_outputs.vision_hidden_states,
474
+ encoder_text_hidden_states=encoder_outputs.text_hidden_states,
475
+ encoder_attentions=encoder_outputs.attentions,
476
+ enc_outputs_class=enc_outputs_class,
477
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
478
+ encoder_logits=encoder_logits,
479
+ encoder_pred_boxes=encoder_pred_boxes,
480
+ )
481
+
482
+ def pad_image_to_rbln_config(self, pixel_values: torch.FloatTensor, pixel_mask: torch.BoolTensor):
483
+ batch_size, _, height, width = pixel_values.shape
484
+ image_height, image_width = self.rbln_config.encoder.image_height, self.rbln_config.encoder.image_width
485
+
486
+ pad_h = image_height - height
487
+ pad_w = image_width - width
488
+ pixel_mask = (
489
+ pixel_mask
490
+ if pixel_mask is not None
491
+ else torch.ones(((batch_size, height, width)), dtype=torch.long, device=pixel_values.device)
492
+ )
493
+
494
+ if pad_h < 0 or pad_w < 0:
495
+ raise ValueError(
496
+ f"Image size {height}x{width} is larger than encoder's image_size {image_height}x{image_width}"
497
+ )
498
+
499
+ if pad_h > 0 or pad_w > 0:
500
+ pixel_values = torch.nn.functional.pad(pixel_values, (0, pad_w, 0, pad_h), value=0)
501
+ pixel_mask = torch.nn.functional.pad(pixel_mask, (0, pad_w, 0, pad_h), value=0)
502
+
503
+ return pixel_values, pixel_mask
504
+
505
+ def pad_text_to_rbln_config(
506
+ self,
507
+ input_ids: torch.LongTensor,
508
+ token_type_ids: Optional[torch.LongTensor] = None,
509
+ attention_mask: Optional[torch.LongTensor] = None,
510
+ ):
511
+ batch_size, seq_len = input_ids.shape
512
+ max_text_len = self.config.max_text_len
513
+ token_type_ids = token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids)
514
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
515
+ if seq_len < max_text_len:
516
+ input_ids = torch.nn.functional.pad(input_ids, (0, max_text_len - seq_len, 0, 0), value=0)
517
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, max_text_len - seq_len, 0, 0), value=0)
518
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, max_text_len - seq_len, 0, 0), value=0)
519
+
520
+ return input_ids, token_type_ids, attention_mask
521
+
522
+ def forward(
523
+ self,
524
+ pixel_values: torch.FloatTensor,
525
+ input_ids: torch.LongTensor,
526
+ token_type_ids: Optional[torch.LongTensor] = None,
527
+ attention_mask: Optional[torch.LongTensor] = None,
528
+ pixel_mask: Optional[torch.BoolTensor] = None,
529
+ encoder_outputs: Optional[Union[GroundingDinoEncoderOutput, Tuple]] = None,
530
+ output_attentions: Optional[bool] = None,
531
+ output_hidden_states: Optional[bool] = None,
532
+ return_dict: Optional[bool] = None,
533
+ labels: List[Dict[str, Union[torch.LongTensor, torch.FloatTensor]]] = None,
534
+ **kwargs,
535
+ ):
536
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
537
+
538
+ # Pad image to rbln_config.image_height and rbln_config.image_width
539
+ pixel_values, pixel_mask = self.pad_image_to_rbln_config(pixel_values, pixel_mask)
540
+ input_ids, token_type_ids, attention_mask = self.pad_text_to_rbln_config(
541
+ input_ids, token_type_ids, attention_mask
542
+ )
543
+
544
+ with torch.inference_mode():
545
+ # First, send images through Grounding DINO base model to obtain encoder + decoder outputs
546
+ outputs = self._model_forward(
547
+ pixel_values=pixel_values,
548
+ input_ids=input_ids,
549
+ token_type_ids=token_type_ids,
550
+ attention_mask=attention_mask,
551
+ pixel_mask=pixel_mask,
552
+ encoder_outputs=encoder_outputs,
553
+ output_attentions=output_attentions,
554
+ output_hidden_states=output_hidden_states,
555
+ return_dict=return_dict,
556
+ **kwargs,
557
+ )
558
+
559
+ idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
560
+ enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
561
+ hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
562
+ init_reference_points = outputs.init_reference_points if return_dict else outputs[1]
563
+ inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3]
564
+
565
+ # class logits + predicted bounding boxes
566
+ outputs_classes = []
567
+ outputs_coords = []
568
+
569
+ # hidden_states are of shape (batch_size, num_stages, height, width)
570
+ # predict class and bounding box deltas for each stage
571
+ num_levels = hidden_states.shape[1]
572
+ for level in range(num_levels):
573
+ if level == 0:
574
+ reference = init_reference_points
575
+ else:
576
+ reference = inter_references_points[:, level - 1]
577
+ reference = torch.special.logit(reference, eps=1e-5)
578
+ outputs_class = self.class_embed[level](
579
+ vision_hidden_state=hidden_states[:, level],
580
+ text_hidden_state=enc_text_hidden_state,
581
+ text_token_mask=attention_mask.bool(),
582
+ )
583
+ delta_bbox = self.bbox_embed[level](hidden_states[:, level])
584
+
585
+ reference_coordinates = reference.shape[-1]
586
+ if reference_coordinates == 4:
587
+ outputs_coord_logits = delta_bbox + reference
588
+ elif reference_coordinates == 2:
589
+ delta_bbox[..., :2] += reference
590
+ outputs_coord_logits = delta_bbox
591
+ else:
592
+ raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
593
+ outputs_coord = outputs_coord_logits.sigmoid()
594
+ outputs_classes.append(outputs_class)
595
+ outputs_coords.append(outputs_coord)
596
+ outputs_class = torch.stack(outputs_classes)
597
+ outputs_coord = torch.stack(outputs_coords)
598
+
599
+ logits = outputs_class[-1]
600
+ pred_boxes = outputs_coord[-1]
601
+
602
+ if not return_dict:
603
+ auxiliary_outputs = []
604
+ output = [logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids]
605
+ output = tuple(out for out in output if out is not None)
606
+ return output
607
+
608
+ return GroundingDinoObjectDetectionOutput(
609
+ logits=logits,
610
+ pred_boxes=pred_boxes,
611
+ last_hidden_state=outputs.last_hidden_state,
612
+ decoder_hidden_states=outputs.decoder_hidden_states,
613
+ decoder_attentions=outputs.decoder_attentions,
614
+ encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision,
615
+ encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text,
616
+ encoder_vision_hidden_states=outputs.encoder_vision_hidden_states,
617
+ encoder_text_hidden_states=outputs.encoder_text_hidden_states,
618
+ encoder_attentions=outputs.encoder_attentions,
619
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
620
+ intermediate_reference_points=outputs.intermediate_reference_points,
621
+ init_reference_points=outputs.init_reference_points,
622
+ enc_outputs_class=outputs.enc_outputs_class,
623
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
624
+ encoder_logits=outputs.encoder_logits,
625
+ encoder_pred_boxes=outputs.encoder_pred_boxes,
626
+ input_ids=input_ids,
627
+ )
628
+
629
+
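The forward pass above never runs with dynamic shapes: pad_image_to_rbln_config and pad_text_to_rbln_config zero-pad every request up to the image size and max_text_len the runtimes were compiled for, and the pixel/attention masks mark the padding as invalid. A small illustration of that padding arithmetic with made-up sizes (the 800x1333 compile-time resolution is an assumption):

# Illustration only: the same right/bottom zero-padding applied by
# pad_image_to_rbln_config, with made-up input and compile-time sizes.
import torch

pixel_values = torch.randn(1, 3, 750, 1280)          # preprocessed image
target_h, target_w = 800, 1333                       # assumed compile-time size
pad_h, pad_w = target_h - 750, target_w - 1280
pixel_values = torch.nn.functional.pad(pixel_values, (0, pad_w, 0, pad_h), value=0)
pixel_mask = torch.nn.functional.pad(
    torch.ones(1, 750, 1280, dtype=torch.long), (0, pad_w, 0, pad_h), value=0
)
print(pixel_values.shape)                            # torch.Size([1, 3, 800, 1333])
print(pixel_mask.sum())                              # 750 * 1280 valid pixels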
630
+ def _update_spatial_shapes(model_config, rbln_config):
631
+ def down_sampled_size(x, depth: int = 1):
632
+ if depth == 0:
633
+ return x
634
+ return down_sampled_size((x + 1) // 2, depth - 1)
635
+
636
+ def num_patches(image_size, patch_size):
637
+ return (image_size + patch_size - 1) // patch_size
638
+
639
+ # update spatial_shapes
640
+ spatial_shapes = []
641
+ backbone_config = model_config.backbone_config
642
+ num_patched_h = num_patches(rbln_config.image_height, backbone_config.patch_size)
643
+ num_patched_w = num_patches(rbln_config.image_width, backbone_config.patch_size)
644
+ for out_layer in backbone_config.out_indices:
645
+ spatial_shapes.append(
646
+ [down_sampled_size(num_patched_h, out_layer - 1), down_sampled_size(num_patched_w, out_layer - 1)]
647
+ )
648
+
649
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
650
+ if model_config.num_feature_levels > len(spatial_shapes):
651
+ last_h, last_w = spatial_shapes[-1][0], spatial_shapes[-1][1]
652
+ h_out = (last_h - 1) // 2 + 1
653
+ w_out = (last_w - 1) // 2 + 1
654
+ spatial_shapes.append([h_out, w_out])
655
+
656
+ rbln_config.spatial_shapes_list = spatial_shapes
657
+
658
+ return rbln_config
659
+
660
+
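To make the helper above concrete, here is a worked example under assumed settings (a Swin-style backbone with patch_size=4 and out_indices=(2, 3, 4), num_feature_levels=4, and an 800x800 compile-time image); the resulting per-level shapes are what the encoder and decoder _update_rbln_config methods sum into vision_seq_len:

# Re-deriving _update_spatial_shapes by hand for assumed settings:
# patch_size=4, out_indices=(2, 3, 4), num_feature_levels=4, 800x800 image.
def down_sampled_size(x, depth=1):
    return x if depth == 0 else down_sampled_size((x + 1) // 2, depth - 1)

num_patched = (800 + 4 - 1) // 4                     # 200 patches per side
shapes = [[down_sampled_size(num_patched, i - 1)] * 2 for i in (2, 3, 4)]
# -> [[100, 100], [50, 50], [25, 25]]
shapes.append([(shapes[-1][0] - 1) // 2 + 1] * 2)    # extra stride-2 level -> [13, 13]
vision_seq_len = sum(h * w for h, w in shapes)       # 10000 + 2500 + 625 + 169 = 13294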
661
+ class RBLNGroundingDinoEncoder(RBLNModel):
662
+ def __post_init__(self, **kwargs):
663
+ self.encoder_runtime = RBLNPytorchRuntime(self.model[0])
664
+
665
+ @classmethod
666
+ def wrap_model_if_needed(
667
+ cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
668
+ ) -> torch.nn.Module:
669
+ model = _GroundingDinoEncoder(model, rbln_config).eval()
670
+ return model
671
+
672
+ @classmethod
673
+ def _update_submodule_config(
674
+ cls,
675
+ model: "PreTrainedModel",
676
+ rbln_config: RBLNModelConfig,
677
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
678
+ ):
679
+ for processor in preprocessors:
680
+ if rbln_config.image_size is None and hasattr(processor, "image_processor"):
681
+ if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
682
+ rbln_config.image_size = (
683
+ processor.image_processor.size["height"],
684
+ processor.image_processor.size["width"],
685
+ )
686
+ elif (
687
+ "longest_edge" in processor.image_processor.size
688
+ and "shortest_edge" in processor.image_processor.size
689
+ ):
690
+ rbln_config.image_size = processor.image_processor.size["longest_edge"]
691
+ elif "shortest_edge" in processor.image_processor.size:
692
+ rbln_config.image_size = processor.image_processor.size["shortest_edge"]
693
+ break
694
+ rbln_config = _update_spatial_shapes(model.config, rbln_config)
695
+ return rbln_config
696
+
697
+ @classmethod
698
+ def _update_rbln_config(
699
+ cls,
700
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
701
+ model: Optional["PreTrainedModel"] = None,
702
+ model_config: RBLNGroundingDinoEncoderConfig = None,
703
+ rbln_config: Optional[RBLNGroundingDinoEncoderConfig] = None,
704
+ ) -> RBLNGroundingDinoEncoderConfig:
705
+ if rbln_config.image_size is None:
706
+ raise ValueError("RBLN config must have image_size set for RBLN optimized GroundingDinoDecoder.")
707
+
708
+ vision_seq_len = int((rbln_config.spatial_shapes[:, 0] * rbln_config.spatial_shapes[:, 1]).sum())
709
+
710
+ input_info = [
711
+ (
712
+ "vision_features",
713
+ [rbln_config.batch_size, vision_seq_len, model_config.d_model],
714
+ "float32",
715
+ ),
716
+ (
717
+ "vision_attention_mask",
718
+ [
719
+ rbln_config.batch_size,
720
+ vision_seq_len,
721
+ model_config.d_model,
722
+ ],
723
+ "float32",
724
+ ),
725
+ (
726
+ "vision_position_embedding",
727
+ [rbln_config.batch_size, vision_seq_len, model_config.d_model],
728
+ "float32",
729
+ ),
730
+ (
731
+ "text_features",
732
+ [rbln_config.batch_size, model_config.max_text_len, model_config.d_model],
733
+ "float32",
734
+ ),
735
+ (
736
+ "text_attention_mask",
737
+ [
738
+ rbln_config.batch_size,
739
+ model_config.max_text_len,
740
+ ],
741
+ "float32",
742
+ ),
743
+ (
744
+ "text_self_attention_masks",
745
+ [
746
+ rbln_config.batch_size,
747
+ model_config.max_text_len,
748
+ model_config.max_text_len,
749
+ ],
750
+ "float32",
751
+ ),
752
+ (
753
+ "reference_points",
754
+ [rbln_config.batch_size, vision_seq_len, 4, 2],
755
+ "float32",
756
+ ),
757
+ ]
758
+
759
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
760
+
761
+ return rbln_config
762
+
763
+ @staticmethod
764
+ def get_reference_points(spatial_shapes, valid_ratios, device):
765
+ reference_points_list = []
766
+ for level, (height, width) in enumerate(spatial_shapes):
767
+ ref_y, ref_x = meshgrid(
768
+ torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
769
+ torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
770
+ indexing="ij",
771
+ )
772
+ # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
773
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
774
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
775
+ ref = torch.stack((ref_x, ref_y), -1)
776
+ reference_points_list.append(ref)
777
+ reference_points = torch.cat(reference_points_list, 1)
778
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
779
+ return reference_points
780
+
781
+ def validate_output_config(self, output_attentions, output_hidden_states):
782
+ if output_attentions != self.rbln_config.output_attentions:
783
+ raise ValueError(
784
+ f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
785
+ f"Please compile again with the correct argument."
786
+ )
787
+
788
+ if output_hidden_states != self.rbln_config.output_hidden_states:
789
+ raise ValueError(
790
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
791
+ f"Please compile again with the correct argument."
792
+ )
793
+
794
+ def forward(
795
+ self,
796
+ vision_features: Tensor,
797
+ vision_attention_mask: Tensor,
798
+ vision_position_embedding: Tensor,
799
+ spatial_shapes: Tensor,
800
+ spatial_shapes_list: List[Tuple[int, int]],
801
+ level_start_index: Tensor,
802
+ valid_ratios: Optional[Tensor] = None,
803
+ text_features: Optional[Tensor] = None,
804
+ text_attention_mask: Optional[Tensor] = None,
805
+ text_position_embedding: Optional[Tensor] = None,
806
+ text_self_attention_masks: Optional[Tensor] = None,
807
+ text_position_ids: Optional[Tensor] = None,
808
+ output_attentions: Optional[bool] = None,
809
+ output_hidden_states: Optional[bool] = None,
810
+ return_dict: Optional[bool] = None,
811
+ ):
812
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
813
+ output_hidden_states = (
814
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
815
+ )
816
+ self.validate_output_config(output_attentions, output_hidden_states)
817
+
818
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
819
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
820
+ vision_attention_mask = vision_attention_mask.to(torch.float32).unsqueeze(-1).repeat(1, 1, self.config.d_model)
821
+
822
+ enc_outputs = self.encoder_runtime(
823
+ vision_features=vision_features,
824
+ vision_attention_mask=vision_attention_mask,
825
+ vision_position_embedding=vision_position_embedding,
826
+ text_features=text_features,
827
+ text_attention_mask=text_attention_mask.to(torch.float32),
828
+ text_self_attention_masks=text_self_attention_masks.to(torch.float32),
829
+ reference_points=reference_points,
830
+ )
831
+
832
+ if not return_dict:
833
+ return tuple(enc_outputs)
834
+
835
+ enc_outputs = list(enc_outputs)
836
+ last_hidden_state_vision = enc_outputs.pop(0)
837
+ last_hidden_state_text = enc_outputs.pop(0)
838
+ vision_hidden_states = (
839
+ tuple([enc_outputs.pop(0) for _ in range(self.config.encoder_layers + 1)])
840
+ if self.rbln_config.output_hidden_states
841
+ else None
842
+ )
843
+ text_hidden_states = (
844
+ tuple([enc_outputs.pop(0) for _ in range(self.config.encoder_layers + 1)])
845
+ if self.rbln_config.output_hidden_states
846
+ else None
847
+ )
848
+ attentions = tuple(enc_outputs) if self.rbln_config.output_attentions else None
849
+
850
+ return GroundingDinoEncoderOutput(
851
+ last_hidden_state_vision=last_hidden_state_vision,
852
+ last_hidden_state_text=last_hidden_state_text,
853
+ vision_hidden_states=vision_hidden_states,
854
+ text_hidden_states=text_hidden_states,
855
+ attentions=attentions,
856
+ )
857
+
858
+
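Because the compiled encoder graph takes the deformable-attention sampling grid as an explicit input, get_reference_points above is evaluated on the host for the fixed spatial shapes and fed in as reference_points. A toy sketch of what that grid looks like for a single 2x3 feature level with all valid ratios equal to 1:

# Toy version of the per-level grid built by get_reference_points: normalized
# pixel-center coordinates for a single 2x3 feature map (valid_ratios == 1).
import torch

height, width = 2, 3
ref_y, ref_x = torch.meshgrid(
    torch.linspace(0.5, height - 0.5, height),
    torch.linspace(0.5, width - 0.5, width),
    indexing="ij",
)
ref = torch.stack((ref_x.reshape(-1) / width, ref_y.reshape(-1) / height), -1)
# ref[:, 0] (x) cycles through 1/6, 1/2, 5/6; ref[:, 1] (y) is 1/4 then 3/4.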
859
+ class RBLNGroundingDinoDecoder(RBLNModel):
860
+ def __post_init__(self, **kwargs):
861
+ self.decoder_runtime = RBLNPytorchRuntime(self.model[0])
862
+
863
+ @classmethod
864
+ def wrap_model_if_needed(
865
+ cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
866
+ ) -> torch.nn.Module:
867
+ return _GroundingDinoDecoder(model, rbln_config).eval()
868
+
869
+ @classmethod
870
+ def _update_submodule_config(
871
+ cls,
872
+ model: "PreTrainedModel",
873
+ rbln_config: RBLNModelConfig,
874
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
875
+ ):
876
+ for processor in preprocessors:
877
+ if rbln_config.image_size is None and hasattr(processor, "image_processor"):
878
+ if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
879
+ rbln_config.image_size = (
880
+ processor.image_processor.size["height"],
881
+ processor.image_processor.size["width"],
882
+ )
883
+ elif (
884
+ "longest_edge" in processor.image_processor.size
885
+ and "shortest_edge" in processor.image_processor.size
886
+ ):
887
+ rbln_config.image_size = processor.image_processor.size["longest_edge"]
888
+ elif "shortest_edge" in processor.image_processor.size:
889
+ rbln_config.image_size = processor.image_processor.size["shortest_edge"]
890
+ break
891
+ rbln_config = _update_spatial_shapes(model.config, rbln_config)
892
+
893
+ return rbln_config
894
+
895
+ @classmethod
896
+ def _update_rbln_config(
897
+ cls,
898
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
899
+ model: Optional["PreTrainedModel"] = None,
900
+ model_config: RBLNGroundingDinoDecoderConfig = None,
901
+ rbln_config: Optional[RBLNGroundingDinoDecoderConfig] = None,
902
+ ) -> RBLNGroundingDinoDecoderConfig:
903
+ if rbln_config.image_size is None:
904
+ raise ValueError("RBLN config must have image_size set for RBLN optimized GroundingDinoDecoder.")
905
+
906
+ vision_seq_len = int((rbln_config.spatial_shapes[:, 0] * rbln_config.spatial_shapes[:, 1]).sum())
907
+
908
+ input_info = [
909
+ (
910
+ "inputs_embeds",
911
+ [rbln_config.batch_size, model_config.num_queries, model_config.d_model],
912
+ "float32",
913
+ ),
914
+ (
915
+ "vision_encoder_hidden_states",
916
+ [
917
+ rbln_config.batch_size,
918
+ vision_seq_len,
919
+ model_config.d_model,
920
+ ],
921
+ "float32",
922
+ ),
923
+ (
924
+ "vision_encoder_attention_mask",
925
+ [rbln_config.batch_size, vision_seq_len, model_config.d_model],
926
+ "float32",
927
+ ),
928
+ (
929
+ "text_encoder_hidden_states",
930
+ [rbln_config.batch_size, model_config.max_text_len, model_config.d_model],
931
+ "float32",
932
+ ),
933
+ (
934
+ "text_encoder_attention_mask",
935
+ [
936
+ rbln_config.batch_size,
937
+ model_config.max_text_len,
938
+ ],
939
+ "float32",
940
+ ),
941
+ (
942
+ "reference_points",
943
+ [
944
+ rbln_config.batch_size,
945
+ model_config.num_queries,
946
+ 4,
947
+ ],
948
+ "float32",
949
+ ),
950
+ (
951
+ "valid_ratios",
952
+ [
953
+ rbln_config.batch_size,
954
+ 4,
955
+ 2,
956
+ ],
957
+ "float32",
958
+ ),
959
+ ]
960
+
961
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
962
+ return rbln_config
963
+
964
+ def validate_output_config(self, output_attentions, output_hidden_states):
965
+ if output_attentions != self.rbln_config.output_attentions:
966
+ raise ValueError(
967
+ f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
968
+ f"Please compile again with the correct argument."
969
+ )
970
+ if output_hidden_states != self.rbln_config.output_hidden_states:
971
+ raise ValueError(
972
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
973
+ f"Please compile again with the correct argument."
974
+ )
975
+
976
+ def forward(
977
+ self,
978
+ inputs_embeds: torch.Tensor,
979
+ vision_encoder_hidden_states: torch.Tensor,
980
+ vision_encoder_attention_mask: torch.Tensor,
981
+ text_encoder_hidden_states: torch.Tensor,
982
+ text_encoder_attention_mask: torch.Tensor,
983
+ reference_points: torch.Tensor,
984
+ valid_ratios: torch.Tensor,
985
+ output_attentions: bool = False,
986
+ output_hidden_states: bool = False,
987
+ return_dict: bool = False,
988
+ **kwargs,
989
+ ):
990
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
991
+ output_hidden_states = (
992
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
993
+ )
994
+ self.validate_output_config(output_attentions, output_hidden_states)
995
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
996
+
997
+ reshaped_vision_encoder_attention_mask = (
998
+ vision_encoder_attention_mask[:, :, None].repeat(1, 1, self.config.d_model).to(torch.float32)
999
+ )
1000
+
1001
+ # Forward pass through the decoder
1002
+ outputs = self.decoder_runtime(
1003
+ inputs_embeds=inputs_embeds,
1004
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1005
+ vision_encoder_attention_mask=reshaped_vision_encoder_attention_mask,
1006
+ text_encoder_hidden_states=text_encoder_hidden_states,
1007
+ text_encoder_attention_mask=text_encoder_attention_mask.to(torch.float32),
1008
+ reference_points=reference_points,
1009
+ valid_ratios=valid_ratios,
1010
+ )
1011
+
1012
+ if not return_dict:
1013
+ return outputs
1014
+
1015
+ outputs = list(outputs)
1016
+ last_hidden_state = outputs.pop(0)
1017
+ intermediate_hidden_states = outputs.pop(0)
1018
+ intermediate_reference_points = outputs.pop(0)
1019
+ hidden_states = (
1020
+ tuple([outputs.pop(0) for _ in range(self.config.decoder_layers + 1)])
1021
+ if self.rbln_config.output_hidden_states
1022
+ else None
1023
+ )
1024
+ attentions = tuple(outputs) if self.rbln_config.output_attentions else None
1025
+
1026
+ return GroundingDinoDecoderOutput(
1027
+ last_hidden_state=last_hidden_state,
1028
+ intermediate_hidden_states=intermediate_hidden_states,
1029
+ intermediate_reference_points=intermediate_reference_points,
1030
+ hidden_states=hidden_states,
1031
+ attentions=attentions,
1032
+ )
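Putting the compiled pieces together, a hypothetical end-to-end inference sketch (the checkpoint name, compiled-model directory, thresholds, and the exact post-processing signature depend on the installed transformers version and are assumptions, not taken from this diff):

# Hypothetical inference sketch; pairs the compiled model with the stock
# Hugging Face processor. The post-processing call follows the upstream
# transformers API and its argument names may differ across versions.
from PIL import Image
from transformers import AutoProcessor
from optimum.rbln import RBLNGroundingDinoForObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = RBLNGroundingDinoForObjectDetection.from_pretrained("grounding-dino-tiny-rbln")

image = Image.open("example.jpg")
inputs = processor(images=image, text="a cat. a remote control.", return_tensors="pt")
outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.35,
    text_threshold=0.25,
    target_sizes=[image.size[::-1]],
)
print(results[0]["boxes"], results[0]["labels"])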