optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -0,0 +1,1048 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+ from transformers.modeling_utils import no_init_weights
21
+ from transformers.models.grounding_dino.modeling_grounding_dino import (
22
+ GroundingDinoContrastiveEmbedding,
23
+ GroundingDinoConvEncoder,
24
+ GroundingDinoDecoderOutput,
25
+ GroundingDinoEncoderOutput,
26
+ GroundingDinoMLPPredictionHead,
27
+ GroundingDinoModel,
28
+ GroundingDinoModelOutput,
29
+ GroundingDinoObjectDetectionOutput,
30
+ build_position_encoding,
31
+ generate_masks_with_special_tokens_and_transfer_map,
32
+ )
33
+ from transformers.pytorch_utils import meshgrid
34
+
35
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
36
+ from ....modeling import RBLNModel
37
+ from ....utils.runtime_utils import RBLNPytorchRuntime
38
+ from .configuration_grounding_dino import (
39
+ RBLNGroundingDinoDecoderConfig,
40
+ RBLNGroundingDinoEncoderConfig,
41
+ RBLNGroundingDinoForObjectDetectionConfig,
42
+ )
43
+ from .grounding_dino_architecture import (
44
+ _GroundingDinoDecoder,
45
+ _GroundingDinoEncoder,
46
+ )
47
+
48
+
49
+ if TYPE_CHECKING:
50
+ from transformers import (
51
+ AutoFeatureExtractor,
52
+ AutoProcessor,
53
+ AutoTokenizer,
54
+ PreTrainedModel,
55
+ )
56
+
57
+
58
+ class RBLNGroundingDinoForObjectDetection(RBLNModel):
59
+ """
60
+ RBLN-optimized Grounding DINO model for object detection.
61
+ This class provides hardware-accelerated inference for Grounding DINO models
62
+ on RBLN devices, supporting multimodal object detection tasks that combine
63
+ vision and language understanding.
64
+
65
+ Grounding DINO is a transformer-based architecture consisting of:
66
+ - A backbone for feature extraction from images
67
+ - An encoder-decoder transformer for processing visual and textual features
68
+ - Object detection heads for predicting bounding boxes and class labels
69
+ """
70
+ _rbln_submodules = [
71
+ {"name": "text_backbone"},
72
+ {"name": "backbone"},
73
+ {"name": "encoder"},
74
+ {"name": "decoder"},
75
+ ]
76
+
77
+ def __post_init__(self, **kwargs):
78
+ self._setup_cpu_instances()
79
+ self.text_projection = RBLNPytorchRuntime(self.model[0])
80
+ self.text_backbone = self.rbln_submodules[0]
81
+ self.backbone = self.rbln_submodules[1]
82
+ self.encoder = self.rbln_submodules[2]
83
+ self.decoder = self.rbln_submodules[3]
84
+
85
+ def _setup_cpu_instances(self):
86
+ state_dict = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
87
+ with no_init_weights():
88
+ config = self.config
89
+ _class_embed = GroundingDinoContrastiveEmbedding(config)
90
+ if config.decoder_bbox_embed_share: # True
91
+ _bbox_embed = GroundingDinoMLPPredictionHead(
92
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
93
+ )
94
+ self.bbox_embed = nn.ModuleList([_bbox_embed for _ in range(config.decoder_layers)])
95
+ else:
96
+ for _ in range(config.decoder_layers):
97
+ _bbox_embed = GroundingDinoMLPPredictionHead(
98
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
99
+ )
100
+ self.bbox_embed = nn.ModuleList([_bbox_embed for _ in range(config.decoder_layers)])
101
+ self.class_embed = nn.ModuleList([_class_embed for _ in range(config.decoder_layers)])
102
+
103
+ backbone = GroundingDinoConvEncoder(config)
104
+ self.backbone_position_embedding = build_position_encoding(self.config)
105
+ # Create input projection layers
106
+ if config.num_feature_levels > 1:
107
+ num_backbone_outs = len(backbone.intermediate_channel_sizes)
108
+ input_proj_list = []
109
+ for i in range(num_backbone_outs):
110
+ in_channels = backbone.intermediate_channel_sizes[i]
111
+ input_proj_list.append(
112
+ nn.Sequential(
113
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1),
114
+ nn.GroupNorm(32, config.d_model),
115
+ )
116
+ )
117
+ for _ in range(config.num_feature_levels - num_backbone_outs):
118
+ input_proj_list.append(
119
+ nn.Sequential(
120
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
121
+ nn.GroupNorm(32, config.d_model),
122
+ )
123
+ )
124
+ in_channels = config.d_model
125
+ self.input_proj_vision = nn.ModuleList(input_proj_list)
126
+ else:
127
+ self.input_proj_vision = nn.ModuleList(
128
+ [
129
+ nn.Sequential(
130
+ nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
131
+ nn.GroupNorm(32, config.d_model),
132
+ )
133
+ ]
134
+ )
135
+
136
+ if config.embedding_init_target or not config.two_stage:
137
+ self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
138
+
139
+ self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model))
140
+
141
+ if config.two_stage:
142
+ self.enc_output = nn.Linear(config.d_model, config.d_model)
143
+ self.enc_output_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
144
+ if (
145
+ config.two_stage_bbox_embed_share
146
+ and config.decoder_bbox_embed_share
147
+ and self.decoder.bbox_embed is not None
148
+ ):
149
+ self.encoder_output_bbox_embed = self.decoder.bbox_embed
150
+ else:
151
+ self.encoder_output_bbox_embed = GroundingDinoMLPPredictionHead(
152
+ input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
153
+ )
154
+
155
+ self.encoder_output_class_embed = GroundingDinoContrastiveEmbedding(config)
156
+ else:
157
+ self.reference_points = nn.Embedding(config.num_queries, 4)
158
+
159
+ self.bbox_embed.load_state_dict(state_dict["bbox_embed"])
160
+ self.class_embed.load_state_dict(state_dict["class_embed"])
161
+ self.input_proj_vision.load_state_dict(state_dict["input_proj_vision"])
162
+ with torch.no_grad():
163
+ self.level_embed.copy_(state_dict["level_embed"])
164
+ if self.config.two_stage:
165
+ self.enc_output.load_state_dict(state_dict["enc_output"])
166
+ self.enc_output_norm.load_state_dict(state_dict["enc_output_norm"])
167
+ self.encoder_output_class_embed.load_state_dict(state_dict["encoder_output_class_embed"])
168
+ self.encoder_output_bbox_embed.load_state_dict(state_dict["encoder_output_bbox_embed"])
169
+ else:
170
+ self.reference_points.load_state_dict(state_dict["reference_points"])
171
+ if self.config.embedding_init_target or not self.config.two_stage:
172
+ self.query_position_embeddings.load_state_dict(state_dict["query_position_embeddings"])
173
+
174
+ if self.config.position_embedding_type == "learned":
175
+ self.backbone_position_embedding.load_state_dict(state_dict["backbone_position_embedding"])
176
+
177
+ @classmethod
178
+ def save_torch_artifacts(
179
+ cls,
180
+ model: "PreTrainedModel",
181
+ save_dir_path: Path,
182
+ subfolder: str,
183
+ rbln_config: RBLNGroundingDinoForObjectDetectionConfig,
184
+ ):
185
+ # Parts of the model that must run on the CPU rather than on an RBLN device
186
+ # have their tensors and weights saved here so they can be restored at load time.
187
+ save_dict = {}
188
+ save_dict["input_proj_vision"] = model.model.input_proj_vision.state_dict()
189
+ save_dict["level_embed"] = model.model.level_embed
190
+ if model.config.two_stage:
191
+ save_dict["enc_output"] = model.model.enc_output.state_dict()
192
+ save_dict["enc_output_norm"] = model.model.enc_output_norm.state_dict()
193
+ save_dict["encoder_output_class_embed"] = model.model.encoder_output_class_embed.state_dict()
194
+ save_dict["encoder_output_bbox_embed"] = model.model.encoder_output_bbox_embed.state_dict()
195
+ else:
196
+ save_dict["reference_points"] = model.model.reference_points.state_dict()
197
+ if model.config.embedding_init_target or not model.config.two_stage:
198
+ save_dict["query_position_embeddings"] = model.model.query_position_embeddings.state_dict()
199
+
200
+ if model.config.position_embedding_type == "learned":
201
+ save_dict["backbone_position_embedding"] = model.model.backbone.position_embedding.state_dict()
202
+
203
+ save_dict["class_embed"] = model.class_embed.state_dict()
204
+ save_dict["bbox_embed"] = model.bbox_embed.state_dict()
205
+
206
+ torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
207
+
208
+ @classmethod
209
+ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
210
+ model.encoder = model.model.encoder
211
+ model.decoder = model.model.decoder
212
+ model.text_backbone = model.model.text_backbone
213
+ model.encoder.config = model.config
214
+ model.decoder.config = model.config
215
+ model.backbone = model.model.backbone.conv_encoder.model
216
+ return model
217
+
218
+ @classmethod
219
+ def _wrap_model_if_needed(
220
+ cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
221
+ ) -> torch.nn.Module:
222
+ return model.model.text_projection
223
+
224
+ @classmethod
225
+ def _update_rbln_config(
226
+ cls,
227
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
228
+ model: Optional["PreTrainedModel"] = None,
229
+ model_config: RBLNGroundingDinoForObjectDetectionConfig = None,
230
+ rbln_config: Optional[RBLNGroundingDinoForObjectDetectionConfig] = None,
231
+ ) -> RBLNGroundingDinoForObjectDetectionConfig:
232
+ input_info = [
233
+ (
234
+ "test_features",
235
+ [rbln_config.batch_size, model_config.max_text_len, model_config.text_config.hidden_size],
236
+ "float32",
237
+ ),
238
+ ]
239
+
240
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
241
+ return rbln_config
242
+
243
+ def generate_encoder_output_proposals(self, *args, **kwargs):
244
+ return GroundingDinoModel.generate_encoder_output_proposals(self, *args, **kwargs)
245
+
246
+ def get_valid_ratio(self, *args, **kwargs):
247
+ return GroundingDinoModel.get_valid_ratio(self, *args, **kwargs)
248
+
249
+ def _model_forward(
250
+ self,
251
+ pixel_values: Tensor,
252
+ input_ids: Tensor,
253
+ token_type_ids: Optional[Tensor] = None,
254
+ attention_mask: Optional[Tensor] = None,
255
+ pixel_mask: Optional[Tensor] = None,
256
+ encoder_outputs=None,
257
+ output_attentions=None,
258
+ output_hidden_states=None,
259
+ return_dict=None,
260
+ _init_reference_points=None,
261
+ ):
262
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
263
+ output_hidden_states = (
264
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
265
+ )
266
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
267
+
268
+ text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids)
269
+
270
+ text_token_mask = attention_mask.bool() # just to avoid renaming everywhere
271
+
272
+ max_text_len = self.config.max_text_len
273
+ if text_self_attention_masks.shape[1] > max_text_len:
274
+ text_self_attention_masks = text_self_attention_masks[:, :max_text_len, :max_text_len]
275
+ position_ids = position_ids[:, :max_text_len]
276
+ input_ids = input_ids[:, :max_text_len]
277
+ token_type_ids = token_type_ids[:, :max_text_len]
278
+ text_token_mask = text_token_mask[:, :max_text_len]
279
+
280
+ # Extract text features from text backbone
281
+ text_outputs = self.text_backbone(
282
+ input_ids, text_self_attention_masks.to(torch.long), token_type_ids, position_ids, return_dict=return_dict
283
+ )
284
+ text_features = text_outputs.last_hidden_state if return_dict else text_outputs[0]
285
+ text_features = self.text_projection(text_features)
286
+
287
+ batch_size, num_channels, height, width = pixel_values.shape
288
+ device = pixel_values.device
289
+
290
+ if pixel_mask is None:
291
+ pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
292
+
293
+ # Extract multi-scale feature maps (cf. Figure 4 in the paper); these are later projected to `config.d_model` channels.
294
+ # First, send pixel_values through the backbone to obtain the feature maps,
295
+ # then pair each one with a mask down-sampled from pixel_mask.
296
+ features = self.backbone(pixel_values)[0]
297
+ vision_features = []
298
+ for feature_map in features:
299
+ # downsample pixel_mask to match shape of corresponding feature_map
300
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
301
+ vision_features.append((feature_map, mask))
302
+
303
+ position_embeddings_list = []
304
+ for feature_map, mask in vision_features:
305
+ # position encoding
306
+ position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))
307
308
+
309
+ # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
310
+ feature_maps = []
311
+ masks = []
312
+ for level, (source, mask) in enumerate(vision_features):
313
+ feature_maps.append(self.input_proj_vision[level](source))
314
+ masks.append(mask)
315
+
316
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
317
+ if self.config.num_feature_levels > len(feature_maps):
318
+ _len_sources = len(feature_maps)
319
+ for level in range(_len_sources, self.config.num_feature_levels):
320
+ if level == _len_sources:
321
+ source = self.input_proj_vision[level](vision_features[-1][0])
322
+ else:
323
+ source = self.input_proj_vision[level](feature_maps[-1])
324
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
325
+ pos_l = self.backbone_position_embedding(source, mask).to(source.dtype)
326
+ feature_maps.append(source)
327
+ masks.append(mask)
328
+ position_embeddings_list.append(pos_l)
329
+
330
+ # Create queries
331
+ query_embeds = None
332
+ if self.config.embedding_init_target or self.config.two_stage:
333
+ query_embeds = self.query_position_embeddings.weight
334
+
335
+ # Prepare encoder inputs (by flattening)
336
+ source_flatten = []
337
+ mask_flatten = []
338
+ lvl_pos_embed_flatten = []
339
+ spatial_shapes_list = []
340
+ for level, (source, mask, pos_embed) in enumerate(zip(feature_maps, masks, position_embeddings_list)):
341
+ batch_size, num_channels, height, width = source.shape
342
+ spatial_shape = (height, width)
343
+ spatial_shapes_list.append(spatial_shape)
344
+ source = source.flatten(2).transpose(1, 2)
345
+ mask = mask.flatten(1)
346
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
347
+ lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
348
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
349
+ source_flatten.append(source)
350
+ mask_flatten.append(mask)
351
+ source_flatten = torch.cat(source_flatten, 1)
352
+ mask_flatten = torch.cat(mask_flatten, 1)
353
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
354
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
355
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
356
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
357
+ valid_ratios = valid_ratios.float()
358
+
359
+ # Fourth, send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through the encoder
360
+ # Also provide spatial_shapes, level_start_index and valid_ratios
361
+ if encoder_outputs is None:
362
+ encoder_outputs = self.encoder(
363
+ vision_features=source_flatten,
364
+ vision_attention_mask=~mask_flatten,
365
+ vision_position_embedding=lvl_pos_embed_flatten,
366
+ spatial_shapes=spatial_shapes,
367
+ spatial_shapes_list=spatial_shapes_list,
368
+ level_start_index=level_start_index,
369
+ valid_ratios=valid_ratios,
370
+ text_features=text_features,
371
+ text_attention_mask=~text_token_mask,
372
+ text_position_embedding=None,
373
+ text_self_attention_masks=~text_self_attention_masks,
374
+ text_position_ids=position_ids,
375
+ output_attentions=output_attentions,
376
+ output_hidden_states=output_hidden_states,
377
+ return_dict=True,
378
+ )
379
+
380
+ # Fifth, prepare decoder inputs
381
+ topk_proposals = None
382
+ enc_outputs_class = None
383
+ enc_outputs_coord_logits = None
384
+ encoder_logits = None
385
+ encoder_pred_boxes = None
386
+ if self.config.two_stage:
387
+ object_query_embedding, output_proposals = self.generate_encoder_output_proposals(
388
+ encoder_outputs[0], ~mask_flatten, spatial_shapes
389
+ )
390
+
391
+ # hack implementation as in two-stage Deformable DETR
392
+ # apply a detection head to each pixel (A.4 in paper)
393
+ # linear projection for bounding box binary classification (i.e. foreground and background)
394
+ enc_outputs_class = self.encoder_output_class_embed(
395
+ object_query_embedding, encoder_outputs[1], text_token_mask
396
+ )
397
+ # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch)
398
+ delta_bbox = self.encoder_output_bbox_embed(object_query_embedding)
399
+ enc_outputs_coord_logits = delta_bbox + output_proposals
400
+
401
+ # only keep top scoring `config.num_queries` proposals
402
+ topk = self.config.num_queries
403
+ topk_logits = enc_outputs_class.max(-1)[0]
404
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]
405
+ topk_coords_logits = torch.gather(
406
+ enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
407
+ )
408
+
409
+ topk_coords_logits = topk_coords_logits.detach()
410
+ reference_points = (
411
+ topk_coords_logits.sigmoid() if _init_reference_points is None else _init_reference_points
412
+ )
413
+ init_reference_points = reference_points
414
+ if query_embeds is not None:
415
+ target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
416
+ else:
417
+ target = torch.gather(
418
+ object_query_embedding, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
419
+ ).detach()
420
+
421
+ # Set intermediate topk proposals (coords and class) for loss computation
422
+ encoder_pred_boxes = reference_points
423
+ encoder_logits = self.encoder_output_class_embed(target, text_features, text_token_mask)
424
+ else:
425
+ target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
426
+ reference_points = self.reference_points.weight.unsqueeze(0).repeat(batch_size, 1, 1).sigmoid()
427
+ init_reference_points = reference_points
428
+
429
+ decoder_outputs = self.decoder(
430
+ inputs_embeds=target,
431
+ vision_encoder_hidden_states=encoder_outputs[0],
432
+ vision_encoder_attention_mask=mask_flatten,
433
+ text_encoder_hidden_states=encoder_outputs[1],
434
+ text_encoder_attention_mask=~text_token_mask,
435
+ reference_points=reference_points,
436
+ spatial_shapes=spatial_shapes,
437
+ spatial_shapes_list=spatial_shapes_list,
438
+ level_start_index=level_start_index,
439
+ valid_ratios=valid_ratios,
440
+ self_attn_mask=None,
441
+ output_attentions=output_attentions,
442
+ output_hidden_states=output_hidden_states,
443
+ return_dict=return_dict,
444
+ )
445
+
446
+ if not return_dict:
447
+ enc_outputs = tuple(
448
+ value
449
+ for value in [
450
+ enc_outputs_class,
451
+ enc_outputs_coord_logits,
452
+ encoder_logits,
453
+ encoder_pred_boxes,
454
+ ]
455
+ if value is not None
456
+ )
457
+ tuple_outputs = (
458
+ (decoder_outputs[0], init_reference_points) + decoder_outputs[1:] + encoder_outputs + enc_outputs
459
+ )
460
+
461
+ return tuple_outputs
462
+
463
+ return GroundingDinoModelOutput(
464
+ last_hidden_state=decoder_outputs.last_hidden_state,
465
+ init_reference_points=init_reference_points,
466
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
467
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
468
+ decoder_hidden_states=decoder_outputs.hidden_states,
469
+ decoder_attentions=decoder_outputs.attentions,
470
+ encoder_last_hidden_state_vision=encoder_outputs.last_hidden_state_vision,
471
+ encoder_last_hidden_state_text=encoder_outputs.last_hidden_state_text,
472
+ encoder_vision_hidden_states=encoder_outputs.vision_hidden_states,
473
+ encoder_text_hidden_states=encoder_outputs.text_hidden_states,
474
+ encoder_attentions=encoder_outputs.attentions,
475
+ enc_outputs_class=enc_outputs_class,
476
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
477
+ encoder_logits=encoder_logits,
478
+ encoder_pred_boxes=encoder_pred_boxes,
479
+ )
480
+
481
+ def pad_image_to_rbln_config(self, pixel_values: torch.FloatTensor, pixel_mask: torch.BoolTensor):
482
+ batch_size, _, height, width = pixel_values.shape
483
+ image_height, image_width = self.rbln_config.encoder.image_height, self.rbln_config.encoder.image_width
484
+
485
+ pad_h = image_height - height
486
+ pad_w = image_width - width
487
+ pixel_mask = (
488
+ pixel_mask
489
+ if pixel_mask is not None
490
+ else torch.ones(((batch_size, height, width)), dtype=torch.long, device=pixel_values.device)
491
+ )
492
+
493
+ if pad_h < 0 or pad_w < 0:
494
+ raise ValueError(
495
+ f"Image size {height}x{width} is larger than encoder's image_size {image_height}x{image_width}"
496
+ )
497
+
498
+ if pad_h > 0 or pad_w > 0:
499
+ pixel_values = torch.nn.functional.pad(pixel_values, (0, pad_w, 0, pad_h), value=0)
500
+ pixel_mask = torch.nn.functional.pad(pixel_mask, (0, pad_w, 0, pad_h), value=0)
501
+
502
+ return pixel_values, pixel_mask
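For intuition, a tiny stand-alone sketch of the zero padding performed above, using assumed compiled dimensions (800x1333 is purely illustrative; the real static size comes from the processor via the rbln_config):

    import torch
    import torch.nn.functional as F

    image_height, image_width = 800, 1333                # assumed compiled (static) input size
    pixel_values = torch.rand(1, 3, 750, 1300)            # a smaller preprocessed image
    pixel_mask = torch.ones(1, 750, 1300, dtype=torch.long)

    pad_h = image_height - pixel_values.shape[-2]          # 50 rows added at the bottom
    pad_w = image_width - pixel_values.shape[-1]           # 33 columns added on the right
    pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h), value=0)
    pixel_mask = F.pad(pixel_mask, (0, pad_w, 0, pad_h), value=0)  # padded pixels masked out

    print(pixel_values.shape, pixel_mask.shape)            # (1, 3, 800, 1333) and (1, 800, 1333)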
503
+
504
+ def pad_text_to_rbln_config(
505
+ self,
506
+ input_ids: torch.LongTensor,
507
+ token_type_ids: Optional[torch.LongTensor] = None,
508
+ attention_mask: Optional[torch.LongTensor] = None,
509
+ ):
510
+ batch_size, seq_len = input_ids.shape
511
+ max_text_len = self.config.max_text_len
512
+ token_type_ids = token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids)
513
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
514
+ if seq_len < max_text_len:
515
+ input_ids = torch.nn.functional.pad(input_ids, (0, max_text_len - seq_len, 0, 0), value=0)
516
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, max_text_len - seq_len, 0, 0), value=0)
517
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, max_text_len - seq_len, 0, 0), value=0)
518
+
519
+ return input_ids, token_type_ids, attention_mask
520
+
521
+ def forward(
522
+ self,
523
+ pixel_values: torch.FloatTensor,
524
+ input_ids: torch.LongTensor,
525
+ token_type_ids: Optional[torch.LongTensor] = None,
526
+ attention_mask: Optional[torch.LongTensor] = None,
527
+ pixel_mask: Optional[torch.BoolTensor] = None,
528
+ encoder_outputs: Optional[Union[GroundingDinoEncoderOutput, Tuple]] = None,
529
+ output_attentions: Optional[bool] = None,
530
+ output_hidden_states: Optional[bool] = None,
531
+ return_dict: Optional[bool] = None,
532
+ **kwargs,
533
+ ) -> Union[GroundingDinoObjectDetectionOutput, Tuple]:
534
+ """
535
+ Forward pass for the RBLN-optimized GroundingDinoForObjectDetection model.
536
+
537
+ Args:
538
+ pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
539
+ input_ids (torch.LongTensor of shape (batch_size, text_sequence_length)): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
540
+ token_type_ids (torch.LongTensor of shape (batch_size, text_sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
541
+ attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
542
+ pixel_mask (torch.Tensor of shape (batch_size, height, width), optional): Mask to avoid performing attention on padding pixel values.
543
+ encoder_outputs (GroundingDinoEncoderOutput or Tuple, optional): Precomputed encoder outputs; the first element is last_hidden_state of shape (batch_size, sequence_length, hidden_size), the hidden states at the output of the last layer of the encoder.
544
+ output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
545
+ output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
546
+ return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
547
+
548
+ Returns:
549
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a GroundingDinoObjectDetectionOutput object.
550
+ """
551
+
552
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
553
+
554
+ # Pad image to rbln_config.image_height and rbln_config.image_width
555
+ pixel_values, pixel_mask = self.pad_image_to_rbln_config(pixel_values, pixel_mask)
556
+ input_ids, token_type_ids, attention_mask = self.pad_text_to_rbln_config(
557
+ input_ids, token_type_ids, attention_mask
558
+ )
559
+
560
+ with torch.inference_mode():
561
+ # First, send images through the Grounding DINO base model to obtain encoder + decoder outputs
562
+ outputs = self._model_forward(
563
+ pixel_values=pixel_values,
564
+ input_ids=input_ids,
565
+ token_type_ids=token_type_ids,
566
+ attention_mask=attention_mask,
567
+ pixel_mask=pixel_mask,
568
+ encoder_outputs=encoder_outputs,
569
+ output_attentions=output_attentions,
570
+ output_hidden_states=output_hidden_states,
571
+ return_dict=return_dict,
572
+ **kwargs,
573
+ )
574
+
575
+ idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
576
+ enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
577
+ hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
578
+ init_reference_points = outputs.init_reference_points if return_dict else outputs[1]
579
+ inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3]
580
+
581
+ # class logits + predicted bounding boxes
582
+ outputs_classes = []
583
+ outputs_coords = []
584
+
585
+ # intermediate hidden_states are of shape (batch_size, num_decoder_layers, num_queries, hidden_size)
586
+ # predict class and bounding box deltas for each stage
587
+ num_levels = hidden_states.shape[1]
588
+ for level in range(num_levels):
589
+ if level == 0:
590
+ reference = init_reference_points
591
+ else:
592
+ reference = inter_references_points[:, level - 1]
593
+ reference = torch.special.logit(reference, eps=1e-5)
594
+ outputs_class = self.class_embed[level](
595
+ vision_hidden_state=hidden_states[:, level],
596
+ text_hidden_state=enc_text_hidden_state,
597
+ text_token_mask=attention_mask.bool(),
598
+ )
599
+ delta_bbox = self.bbox_embed[level](hidden_states[:, level])
600
+
601
+ reference_coordinates = reference.shape[-1]
602
+ if reference_coordinates == 4:
603
+ outputs_coord_logits = delta_bbox + reference
604
+ elif reference_coordinates == 2:
605
+ delta_bbox[..., :2] += reference
606
+ outputs_coord_logits = delta_bbox
607
+ else:
608
+ raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
609
+ outputs_coord = outputs_coord_logits.sigmoid()
610
+ outputs_classes.append(outputs_class)
611
+ outputs_coords.append(outputs_coord)
612
+ outputs_class = torch.stack(outputs_classes)
613
+ outputs_coord = torch.stack(outputs_coords)
614
+
615
+ logits = outputs_class[-1]
616
+ pred_boxes = outputs_coord[-1]
617
+
618
+ if not return_dict:
619
+ auxiliary_outputs = []
620
+ output = [logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids]
621
+ output = tuple(out for out in output if out is not None)
622
+ return output
623
+
624
+ return GroundingDinoObjectDetectionOutput(
625
+ logits=logits,
626
+ pred_boxes=pred_boxes,
627
+ last_hidden_state=outputs.last_hidden_state,
628
+ decoder_hidden_states=outputs.decoder_hidden_states,
629
+ decoder_attentions=outputs.decoder_attentions,
630
+ encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision,
631
+ encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text,
632
+ encoder_vision_hidden_states=outputs.encoder_vision_hidden_states,
633
+ encoder_text_hidden_states=outputs.encoder_text_hidden_states,
634
+ encoder_attentions=outputs.encoder_attentions,
635
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
636
+ intermediate_reference_points=outputs.intermediate_reference_points,
637
+ init_reference_points=outputs.init_reference_points,
638
+ enc_outputs_class=outputs.enc_outputs_class,
639
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
640
+ encoder_logits=outputs.encoder_logits,
641
+ encoder_pred_boxes=outputs.encoder_pred_boxes,
642
+ input_ids=input_ids,
643
+ )
644
+
645
+
646
+ def _update_spatial_shapes(model_config, rbln_config):
647
+ def down_sampled_size(x, depth: int = 1):
648
+ if depth == 0:
649
+ return x
650
+ return down_sampled_size((x + 1) // 2, depth - 1)
651
+
652
+ def num_patches(image_size, patch_size):
653
+ return (image_size + patch_size - 1) // patch_size
654
+
655
+ # update spatial_shapes
656
+ spatial_shapes = []
657
+ backbone_config = model_config.backbone_config
658
+ num_patched_h = num_patches(rbln_config.image_height, backbone_config.patch_size)
659
+ num_patched_w = num_patches(rbln_config.image_width, backbone_config.patch_size)
660
+ for out_layer in backbone_config.out_indices:
661
+ spatial_shapes.append(
662
+ [down_sampled_size(num_patched_h, out_layer - 1), down_sampled_size(num_patched_w, out_layer - 1)]
663
+ )
664
+
665
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
666
+ if model_config.num_feature_levels > len(spatial_shapes):
667
+ last_h, last_w = spatial_shapes[-1][0], spatial_shapes[-1][1]
668
+ h_out = (last_h - 1) // 2 + 1
669
+ w_out = (last_w - 1) // 2 + 1
670
+ spatial_shapes.append([h_out, w_out])
671
+
672
+ rbln_config.spatial_shapes_list = spatial_shapes
673
+
674
+ return rbln_config
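To make the shape bookkeeping concrete, here is a stand-alone re-computation of what _update_spatial_shapes produces. The backbone values (patch_size=4, out_indices=(2, 3, 4), i.e. a Swin-style encoder), num_feature_levels=4, and the 800x800 compiled image size are assumptions chosen for illustration:

    def down_sampled_size(x, depth=1):
        # Halve (rounding up) `depth` times, mirroring the helper above.
        if depth == 0:
            return x
        return down_sampled_size((x + 1) // 2, depth - 1)

    def num_patches(image_size, patch_size):
        return (image_size + patch_size - 1) // patch_size

    image_height = image_width = 800            # assumed compiled size
    patch_size, out_indices = 4, (2, 3, 4)      # assumed Swin-style backbone config
    num_feature_levels = 4                      # Grounding DINO default

    h = num_patches(image_height, patch_size)   # 200
    w = num_patches(image_width, patch_size)    # 200
    spatial_shapes = [
        [down_sampled_size(h, i - 1), down_sampled_size(w, i - 1)] for i in out_indices
    ]                                           # [[100, 100], [50, 50], [25, 25]]
    if num_feature_levels > len(spatial_shapes):
        last_h, last_w = spatial_shapes[-1]
        spatial_shapes.append([(last_h - 1) // 2 + 1, (last_w - 1) // 2 + 1])  # [13, 13]

    vision_seq_len = sum(h * w for h, w in spatial_shapes)
    print(spatial_shapes, vision_seq_len)       # ..., 13294

The resulting flattened token count (vision_seq_len) is what the encoder and decoder compile configs below use to size vision_features, vision_attention_mask, and reference_points.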
675
+
676
+
677
+ class RBLNGroundingDinoEncoder(RBLNModel):
678
+ def __post_init__(self, **kwargs):
679
+ self.encoder_runtime = RBLNPytorchRuntime(self.model[0])
680
+
681
+ @classmethod
682
+ def _wrap_model_if_needed(
683
+ cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
684
+ ) -> torch.nn.Module:
685
+ model = _GroundingDinoEncoder(model, rbln_config).eval()
686
+ return model
687
+
688
+ @classmethod
689
+ def _update_submodule_config(
690
+ cls,
691
+ model: "PreTrainedModel",
692
+ rbln_config: RBLNModelConfig,
693
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
694
+ ):
695
+ for processor in preprocessors:
696
+ if rbln_config.image_size is None and hasattr(processor, "image_processor"):
697
+ if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
698
+ rbln_config.image_size = (
699
+ processor.image_processor.size["height"],
700
+ processor.image_processor.size["width"],
701
+ )
702
+ elif (
703
+ "longest_edge" in processor.image_processor.size
704
+ and "shortest_edge" in processor.image_processor.size
705
+ ):
706
+ rbln_config.image_size = processor.image_processor.size["longest_edge"]
707
+ elif "shortest_edge" in processor.image_processor.size:
708
+ rbln_config.image_size = processor.image_processor.size["shortest_edge"]
709
+ break
710
+ rbln_config = _update_spatial_shapes(model.config, rbln_config)
711
+ return rbln_config
712
+
713
+ @classmethod
714
+ def _update_rbln_config(
715
+ cls,
716
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
717
+ model: Optional["PreTrainedModel"] = None,
718
+ model_config: RBLNGroundingDinoEncoderConfig = None,
719
+ rbln_config: Optional[RBLNGroundingDinoEncoderConfig] = None,
720
+ ) -> RBLNGroundingDinoEncoderConfig:
721
+ if rbln_config.image_size is None:
722
+ raise ValueError("RBLN config must have image_size set for RBLN optimized GroundingDinoDecoder.")
723
+
724
+ vision_seq_len = int((rbln_config.spatial_shapes[:, 0] * rbln_config.spatial_shapes[:, 1]).sum())
725
+
726
+ input_info = [
727
+ (
728
+ "vision_features",
729
+ [rbln_config.batch_size, vision_seq_len, model_config.d_model],
730
+ "float32",
731
+ ),
732
+ (
733
+ "vision_attention_mask",
734
+ [
735
+ rbln_config.batch_size,
736
+ vision_seq_len,
737
+ model_config.d_model,
738
+ ],
739
+ "float32",
740
+ ),
741
+ (
742
+ "vision_position_embedding",
743
+ [rbln_config.batch_size, vision_seq_len, model_config.d_model],
744
+ "float32",
745
+ ),
746
+ (
747
+ "text_features",
748
+ [rbln_config.batch_size, model_config.max_text_len, model_config.d_model],
749
+ "float32",
750
+ ),
751
+ (
752
+ "text_attention_mask",
753
+ [
754
+ rbln_config.batch_size,
755
+ model_config.max_text_len,
756
+ ],
757
+ "float32",
758
+ ),
759
+ (
760
+ "text_self_attention_masks",
761
+ [
762
+ rbln_config.batch_size,
763
+ model_config.max_text_len,
764
+ model_config.max_text_len,
765
+ ],
766
+ "float32",
767
+ ),
768
+ (
769
+ "reference_points",
770
+ [rbln_config.batch_size, vision_seq_len, 4, 2],
771
+ "float32",
772
+ ),
773
+ ]
774
+
775
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
776
+
777
+ return rbln_config
778
+
779
+ @staticmethod
780
+ def get_reference_points(spatial_shapes, valid_ratios, device):
781
+ reference_points_list = []
782
+ for level, (height, width) in enumerate(spatial_shapes):
783
+ ref_y, ref_x = meshgrid(
784
+ torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
785
+ torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
786
+ indexing="ij",
787
+ )
788
+ # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
789
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
790
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
791
+ ref = torch.stack((ref_x, ref_y), -1)
792
+ reference_points_list.append(ref)
793
+ reference_points = torch.cat(reference_points_list, 1)
794
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
795
+ return reference_points
796
+
797
+ def validate_output_config(self, output_attentions, output_hidden_states):
798
+ if output_attentions != self.rbln_config.output_attentions:
799
+ raise ValueError(
800
+ f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
801
+ f"Please compile again with the correct argument."
802
+ )
803
+
804
+ if output_hidden_states != self.rbln_config.output_hidden_states:
805
+ raise ValueError(
806
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
807
+ f"Please compile again with the correct argument."
808
+ )
809
+
810
+ def forward(
811
+ self,
812
+ vision_features: Tensor,
813
+ vision_attention_mask: Tensor,
814
+ vision_position_embedding: Tensor,
815
+ spatial_shapes: Tensor,
816
+ spatial_shapes_list: List[Tuple[int, int]],
817
+ level_start_index: Tensor,
818
+ valid_ratios: Optional[Tensor] = None,
819
+ text_features: Optional[Tensor] = None,
820
+ text_attention_mask: Optional[Tensor] = None,
821
+ text_position_embedding: Optional[Tensor] = None,
822
+ text_self_attention_masks: Optional[Tensor] = None,
823
+ text_position_ids: Optional[Tensor] = None,
824
+ output_attentions: Optional[bool] = None,
825
+ output_hidden_states: Optional[bool] = None,
826
+ return_dict: Optional[bool] = None,
827
+ ):
828
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
829
+ output_hidden_states = (
830
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
831
+ )
832
+ self.validate_output_config(output_attentions, output_hidden_states)
833
+
834
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
835
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
836
+ vision_attention_mask = vision_attention_mask.to(torch.float32).unsqueeze(-1).repeat(1, 1, self.config.d_model)
837
+
838
+ enc_outputs = self.encoder_runtime(
839
+ vision_features=vision_features,
840
+ vision_attention_mask=vision_attention_mask,
841
+ vision_position_embedding=vision_position_embedding,
842
+ text_features=text_features,
843
+ text_attention_mask=text_attention_mask.to(torch.float32),
844
+ text_self_attention_masks=text_self_attention_masks.to(torch.float32),
845
+ reference_points=reference_points,
846
+ )
847
+
848
+ if not return_dict:
849
+ return tuple(enc_outputs)
850
+
851
+ enc_outputs = list(enc_outputs)
852
+ last_hidden_state_vision = enc_outputs.pop(0)
853
+ last_hidden_state_text = enc_outputs.pop(0)
854
+ vision_hidden_states = (
855
+ tuple([enc_outputs.pop(0) for _ in range(self.config.encoder_layers + 1)])
856
+ if self.rbln_config.output_hidden_states
857
+ else None
858
+ )
859
+ text_hidden_states = (
860
+ tuple([enc_outputs.pop(0) for _ in range(self.config.encoder_layers + 1)])
861
+ if self.rbln_config.output_hidden_states
862
+ else None
863
+ )
864
+ attentions = tuple(enc_outputs) if self.rbln_config.output_attentions else None
865
+
866
+ return GroundingDinoEncoderOutput(
867
+ last_hidden_state_vision=last_hidden_state_vision,
868
+ last_hidden_state_text=last_hidden_state_text,
869
+ vision_hidden_states=vision_hidden_states,
870
+ text_hidden_states=text_hidden_states,
871
+ attentions=attentions,
872
+ )
873
+
874
+
875
+ class RBLNGroundingDinoDecoder(RBLNModel):
876
+ def __post_init__(self, **kwargs):
877
+ self.decoder_runtime = RBLNPytorchRuntime(self.model[0])
878
+
879
+ @classmethod
880
+ def _wrap_model_if_needed(
881
+ cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
882
+ ) -> torch.nn.Module:
883
+ return _GroundingDinoDecoder(model, rbln_config).eval()
884
+
885
+ @classmethod
886
+ def _update_submodule_config(
887
+ cls,
888
+ model: "PreTrainedModel",
889
+ rbln_config: RBLNModelConfig,
890
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
891
+ ):
892
+ for processor in preprocessors:
893
+ if rbln_config.image_size is None and hasattr(processor, "image_processor"):
894
+ if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
895
+ rbln_config.image_size = (
896
+ processor.image_processor.size["height"],
897
+ processor.image_processor.size["width"],
898
+ )
899
+ elif (
900
+ "longest_edge" in processor.image_processor.size
901
+ and "shortest_edge" in processor.image_processor.size
902
+ ):
903
+ rbln_config.image_size = processor.image_processor.size["longest_edge"]
904
+ elif "shortest_edge" in processor.image_processor.size:
905
+ rbln_config.image_size = processor.image_processor.size["shortest_edge"]
906
+ break
907
+ rbln_config = _update_spatial_shapes(model.config, rbln_config)
908
+
909
+ return rbln_config
910
+
911
+ @classmethod
912
+ def _update_rbln_config(
913
+ cls,
914
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
915
+ model: Optional["PreTrainedModel"] = None,
916
+ model_config: RBLNGroundingDinoDecoderConfig = None,
917
+ rbln_config: Optional[RBLNGroundingDinoDecoderConfig] = None,
918
+ ) -> RBLNGroundingDinoDecoderConfig:
919
+ if rbln_config.image_size is None:
920
+ raise ValueError("RBLN config must have image_size set for RBLN optimized GroundingDinoDecoder.")
921
+
922
+ vision_seq_len = int((rbln_config.spatial_shapes[:, 0] * rbln_config.spatial_shapes[:, 1]).sum())
923
+
924
+ input_info = [
925
+ (
926
+ "inputs_embeds",
927
+ [rbln_config.batch_size, model_config.num_queries, model_config.d_model],
928
+ "float32",
929
+ ),
930
+ (
931
+ "vision_encoder_hidden_states",
932
+ [
933
+ rbln_config.batch_size,
934
+ vision_seq_len,
935
+ model_config.d_model,
936
+ ],
937
+ "float32",
938
+ ),
939
+ (
940
+ "vision_encoder_attention_mask",
941
+ [rbln_config.batch_size, vision_seq_len, model_config.d_model],
942
+ "float32",
943
+ ),
944
+ (
945
+ "text_encoder_hidden_states",
946
+ [rbln_config.batch_size, model_config.max_text_len, model_config.d_model],
947
+ "float32",
948
+ ),
949
+ (
950
+ "text_encoder_attention_mask",
951
+ [
952
+ rbln_config.batch_size,
953
+ model_config.max_text_len,
954
+ ],
955
+ "float32",
956
+ ),
957
+ (
958
+ "reference_points",
959
+ [
960
+ rbln_config.batch_size,
961
+ model_config.num_queries,
962
+ 4,
963
+ ],
964
+ "float32",
965
+ ),
966
+ (
967
+ "valid_ratios",
968
+ [
969
+ rbln_config.batch_size,
970
+ 4,
971
+ 2,
972
+ ],
973
+ "float32",
974
+ ),
975
+ ]
976
+
977
+ rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
978
+ return rbln_config
979
+
980
+ def validate_output_config(self, output_attentions, output_hidden_states):
981
+ if output_attentions != self.rbln_config.output_attentions:
982
+ raise ValueError(
983
+ f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
984
+ f"Please compile again with the correct argument."
985
+ )
986
+ if output_hidden_states != self.rbln_config.output_hidden_states:
987
+ raise ValueError(
988
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
989
+ f"Please compile again with the correct argument."
990
+ )
991
+
992
+ def forward(
993
+ self,
994
+ inputs_embeds: torch.Tensor,
995
+ vision_encoder_hidden_states: torch.Tensor,
996
+ vision_encoder_attention_mask: torch.Tensor,
997
+ text_encoder_hidden_states: torch.Tensor,
998
+ text_encoder_attention_mask: torch.Tensor,
999
+ reference_points: torch.Tensor,
1000
+ valid_ratios: torch.Tensor,
1001
+ output_attentions: bool = False,
1002
+ output_hidden_states: bool = False,
1003
+ return_dict: bool = False,
1004
+ **kwargs,
1005
+ ):
1006
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1007
+ output_hidden_states = (
1008
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1009
+ )
1010
+ self.validate_output_config(output_attentions, output_hidden_states)
1011
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1012
+
1013
+ reshaped_vision_encoder_attention_mask = (
1014
+ vision_encoder_attention_mask[:, :, None].repeat(1, 1, self.config.d_model).to(torch.float32)
1015
+ )
1016
+
1017
+ # Forward pass through the decoder
1018
+ outputs = self.decoder_runtime(
1019
+ inputs_embeds=inputs_embeds,
1020
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1021
+ vision_encoder_attention_mask=reshaped_vision_encoder_attention_mask,
1022
+ text_encoder_hidden_states=text_encoder_hidden_states,
1023
+ text_encoder_attention_mask=text_encoder_attention_mask.to(torch.float32),
1024
+ reference_points=reference_points,
1025
+ valid_ratios=valid_ratios,
1026
+ )
1027
+
1028
+ if not return_dict:
1029
+ return outputs
1030
+
1031
+ outputs = list(outputs)
1032
+ last_hidden_state = outputs.pop(0)
1033
+ intermediate_hidden_states = outputs.pop(0)
1034
+ intermediate_reference_points = outputs.pop(0)
1035
+ hidden_states = (
1036
+ tuple([outputs.pop(0) for _ in range(self.config.decoder_layers + 1)])
1037
+ if self.rbln_config.output_hidden_states
1038
+ else None
1039
+ )
1040
+ attentions = tuple(outputs) if self.rbln_config.output_attentions else None
1041
+
1042
+ return GroundingDinoDecoderOutput(
1043
+ last_hidden_state=last_hidden_state,
1044
+ intermediate_hidden_states=intermediate_hidden_states,
1045
+ intermediate_reference_points=intermediate_reference_points,
1046
+ hidden_states=hidden_states,
1047
+ attentions=attentions,
1048
+ )