optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1048 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor, nn
|
|
20
|
+
from transformers.modeling_utils import no_init_weights
|
|
21
|
+
from transformers.models.grounding_dino.modeling_grounding_dino import (
|
|
22
|
+
GroundingDinoContrastiveEmbedding,
|
|
23
|
+
GroundingDinoConvEncoder,
|
|
24
|
+
GroundingDinoDecoderOutput,
|
|
25
|
+
GroundingDinoEncoderOutput,
|
|
26
|
+
GroundingDinoMLPPredictionHead,
|
|
27
|
+
GroundingDinoModel,
|
|
28
|
+
GroundingDinoModelOutput,
|
|
29
|
+
GroundingDinoObjectDetectionOutput,
|
|
30
|
+
build_position_encoding,
|
|
31
|
+
generate_masks_with_special_tokens_and_transfer_map,
|
|
32
|
+
)
|
|
33
|
+
from transformers.pytorch_utils import meshgrid
|
|
34
|
+
|
|
35
|
+
from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
|
36
|
+
from ....modeling import RBLNModel
|
|
37
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
38
|
+
from .configuration_grounding_dino import (
|
|
39
|
+
RBLNGroundingDinoDecoderConfig,
|
|
40
|
+
RBLNGroundingDinoEncoderConfig,
|
|
41
|
+
RBLNGroundingDinoForObjectDetectionConfig,
|
|
42
|
+
)
|
|
43
|
+
from .grounding_dino_architecture import (
|
|
44
|
+
_GroundingDinoDecoder,
|
|
45
|
+
_GroundingDinoEncoder,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
if TYPE_CHECKING:
|
|
50
|
+
from transformers import (
|
|
51
|
+
AutoFeatureExtractor,
|
|
52
|
+
AutoProcessor,
|
|
53
|
+
AutoTokenizer,
|
|
54
|
+
PreTrainedModel,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class RBLNGroundingDinoForObjectDetection(RBLNModel):
|
|
59
|
+
"""
RBLN optimized Grounding DINO model for object detection.

This class provides hardware-accelerated inference for Grounding DINO models
on RBLN devices, supporting multimodal object detection tasks that combine
vision and language understanding.

Grounding DINO is a transformer-based architecture consisting of:
- A backbone for feature extraction from images
- An encoder-decoder transformer for processing visual and textual features
- Object detection heads for predicting bounding boxes and class labels

The text backbone, image backbone, encoder and decoder are compiled
submodules. The text projection is compiled as this model's own graph. The
detection heads and several embeddings are rebuilt as plain torch modules on
the host (see ``_setup_cpu_instances``).
"""

# Compiled submodules. `__post_init__` indexes `self.rbln_submodules` positionally
# (0 = text_backbone, 1 = backbone, 2 = encoder, 3 = decoder). That presumably
# follows the order of this list, so keep the two in sync.
_rbln_submodules = [
    {"name": "text_backbone"},
    {"name": "backbone"},
    {"name": "encoder"},
    {"name": "decoder"},
]
|
|
76
|
+
|
|
77
|
+
def __post_init__(self, **kwargs):
    """Finish initialization once the compiled runtimes have been loaded.

    Rebuilds the host-side torch modules, wraps the compiled text-projection
    graph in a PyTorch-style runtime, and exposes the four compiled submodules
    under descriptive attribute names.
    """
    self._setup_cpu_instances()

    # Presumably the graph compiled from `_wrap_model_if_needed` (the text projection).
    compiled_text_projection = self.model[0]
    self.text_projection = RBLNPytorchRuntime(compiled_text_projection)

    # Positional mapping onto `rbln_submodules`; indexing (not unpacking) keeps the
    # same IndexError if fewer than four submodules are present.
    submodule_names = ("text_backbone", "backbone", "encoder", "decoder")
    for position, attr_name in enumerate(submodule_names):
        setattr(self, attr_name, self.rbln_submodules[position])
|
|
84
|
+
|
|
85
|
+
def _setup_cpu_instances(self):
    """Rebuild the host-side (non-compiled) torch modules and load their weights.

    Builds, with the layout implied by ``self.config``:
    - the per-decoder-layer class/box heads,
    - the backbone position embedding,
    - the vision input projections and the feature-level embedding,
    - either the two-stage encoder-output modules (``config.two_stage``) or the
      ``reference_points`` embedding,
    - the ``query_position_embeddings`` embedding when required.

    It then fills them with the weights that ``save_torch_artifacts`` stored in
    ``torch_artifacts.pth`` under ``self.model_save_dir / self.subfolder``.
    """
    # NOTE(review): `weights_only=False` lets torch.load unpickle arbitrary objects, so a
    # tampered model directory could execute code here. The file written by
    # `save_torch_artifacts` holds only state dicts plus the `level_embed` parameter;
    # consider `weights_only=True` once that is verified.
    state_dict = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
    config = self.config

    def _new_bbox_head():
        # MLP head: d_model -> d_model -> 4 outputs, 3 layers.
        return GroundingDinoMLPPredictionHead(
            input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3
        )

    with no_init_weights():
        # A single class-head object referenced by every decoder layer.
        _class_embed = GroundingDinoContrastiveEmbedding(config)
        if config.decoder_bbox_embed_share:
            # A single box-head object referenced by every decoder layer.
            _bbox_embed = _new_bbox_head()
            self.bbox_embed = nn.ModuleList([_bbox_embed for _ in range(config.decoder_layers)])
        else:
            # An independent box head per decoder layer. Reusing one instance here would
            # make load_state_dict write every layer's weights into the same module,
            # leaving all layers with the last layer's weights.
            self.bbox_embed = nn.ModuleList([_new_bbox_head() for _ in range(config.decoder_layers)])
        self.class_embed = nn.ModuleList([_class_embed for _ in range(config.decoder_layers)])

        # Used here only to read `intermediate_channel_sizes` for the input projections.
        # NOTE(review): this builds a full conv backbone on the host just for channel
        # sizes; consider deriving them from the saved projection weights instead.
        backbone = GroundingDinoConvEncoder(config)
        self.backbone_position_embedding = build_position_encoding(self.config)

        # Create input projection layers
        if config.num_feature_levels > 1:
            num_backbone_outs = len(backbone.intermediate_channel_sizes)
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels = backbone.intermediate_channel_sizes[i]
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, config.d_model, kernel_size=1),
                        nn.GroupNorm(32, config.d_model),
                    )
                )
            # Extra levels are produced by strided convs chained after the last backbone output.
            for _ in range(config.num_feature_levels - num_backbone_outs):
                input_proj_list.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1),
                        nn.GroupNorm(32, config.d_model),
                    )
                )
                in_channels = config.d_model
            self.input_proj_vision = nn.ModuleList(input_proj_list)
        else:
            self.input_proj_vision = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1),
                        nn.GroupNorm(32, config.d_model),
                    )
                ]
            )

        if config.embedding_init_target or not config.two_stage:
            self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)

        # Uninitialized storage; overwritten from the saved artifact below.
        self.level_embed = nn.Parameter(torch.Tensor(config.num_feature_levels, config.d_model))

        if config.two_stage:
            self.enc_output = nn.Linear(config.d_model, config.d_model)
            self.enc_output_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
            # Always a standalone MLP head. `save_torch_artifacts` stores this module's own
            # state dict, and `self.decoder` must not be read here because
            # `__post_init__` assigns it only after this method returns.
            # NOTE(review): matches Hugging Face GroundingDinoModel, where
            # `decoder.bbox_embed` is still unset when this head is created; confirm for
            # the installed transformers version.
            self.encoder_output_bbox_embed = _new_bbox_head()
            self.encoder_output_class_embed = GroundingDinoContrastiveEmbedding(config)
        else:
            self.reference_points = nn.Embedding(config.num_queries, 4)

    self.bbox_embed.load_state_dict(state_dict["bbox_embed"])
    self.class_embed.load_state_dict(state_dict["class_embed"])
    self.input_proj_vision.load_state_dict(state_dict["input_proj_vision"])
    with torch.no_grad():
        # `level_embed` was saved as a tensor/parameter, not a state dict.
        self.level_embed.copy_(state_dict["level_embed"])
    if config.two_stage:
        self.enc_output.load_state_dict(state_dict["enc_output"])
        self.enc_output_norm.load_state_dict(state_dict["enc_output_norm"])
        self.encoder_output_class_embed.load_state_dict(state_dict["encoder_output_class_embed"])
        self.encoder_output_bbox_embed.load_state_dict(state_dict["encoder_output_bbox_embed"])
    else:
        self.reference_points.load_state_dict(state_dict["reference_points"])
    if config.embedding_init_target or not config.two_stage:
        self.query_position_embeddings.load_state_dict(state_dict["query_position_embeddings"])

    if config.position_embedding_type == "learned":
        self.backbone_position_embedding.load_state_dict(state_dict["backbone_position_embedding"])
|
|
176
|
+
|
|
177
|
+
@classmethod
def save_torch_artifacts(
    cls,
    model: "PreTrainedModel",
    save_dir_path: Path,
    subfolder: str,
    rbln_config: RBLNGroundingDinoForObjectDetectionConfig,
):
    """Persist the weights of modules that run on the host instead of the RBLN device.

    Everything goes into one dict saved as ``torch_artifacts.pth``.
    ``_setup_cpu_instances`` reads the file back.

    Args:
        model: The Hugging Face Grounding DINO object-detection model being compiled.
        save_dir_path: Root directory of the compiled model.
        subfolder: Sub-directory under ``save_dir_path`` that receives the file.
        rbln_config: RBLN configuration (not used by this method).
    """
    inner = model.model
    hf_config = model.config

    artifacts = {
        "input_proj_vision": inner.input_proj_vision.state_dict(),
        # Stored as the parameter object itself rather than a state dict.
        "level_embed": inner.level_embed,
    }

    if not hf_config.two_stage:
        artifacts["reference_points"] = inner.reference_points.state_dict()
    else:
        artifacts["enc_output"] = inner.enc_output.state_dict()
        artifacts["enc_output_norm"] = inner.enc_output_norm.state_dict()
        artifacts["encoder_output_class_embed"] = inner.encoder_output_class_embed.state_dict()
        artifacts["encoder_output_bbox_embed"] = inner.encoder_output_bbox_embed.state_dict()

    needs_query_embeddings = hf_config.embedding_init_target or not hf_config.two_stage
    if needs_query_embeddings:
        artifacts["query_position_embeddings"] = inner.query_position_embeddings.state_dict()

    if hf_config.position_embedding_type == "learned":
        artifacts["backbone_position_embedding"] = inner.backbone.position_embedding.state_dict()

    # The detection heads hang off the top-level model, not `model.model`.
    artifacts["class_embed"] = model.class_embed.state_dict()
    artifacts["bbox_embed"] = model.bbox_embed.state_dict()

    artifact_path = save_dir_path / subfolder / "torch_artifacts.pth"
    torch.save(artifacts, artifact_path)
|
|
207
|
+
|
|
208
|
+
@classmethod
def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
    """Expose the inner Grounding DINO components as top-level attributes of ``model``.

    The encoder, decoder and text backbone of ``model.model`` are re-attached
    directly on ``model``. The encoder and decoder receive the top-level config.
    ``backbone`` is bound to the raw conv-encoder network
    (``model.model.backbone.conv_encoder.model``). The same (mutated) ``model``
    is returned.
    """
    inner = model.model

    model.encoder = inner.encoder
    model.decoder = inner.decoder
    model.text_backbone = inner.text_backbone

    for component in (model.encoder, model.decoder):
        component.config = model.config

    model.backbone = inner.backbone.conv_encoder.model
    return model
|
|
217
|
+
|
|
218
|
+
@classmethod
def _wrap_model_if_needed(
    cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
) -> torch.nn.Module:
    """Select the module compiled as this model's own graph: the text projection.

    ``rbln_config`` is accepted for interface compatibility and is not used.
    """
    text_projection = model.model.text_projection
    return text_projection
|
|
223
|
+
|
|
224
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
    model: Optional["PreTrainedModel"] = None,
    model_config: RBLNGroundingDinoForObjectDetectionConfig = None,
    rbln_config: Optional[RBLNGroundingDinoForObjectDetectionConfig] = None,
) -> RBLNGroundingDinoForObjectDetectionConfig:
    """Register the compile configuration for the text-projection graph.

    The graph takes a single float32 input of shape
    ``[batch_size, max_text_len, text_config.hidden_size]``.
    ``preprocessors`` and ``model`` are not used.

    Returns:
        The same ``rbln_config``, with its compile configs set.
    """
    text_feature_shape = [
        rbln_config.batch_size,
        model_config.max_text_len,
        model_config.text_config.hidden_size,
    ]
    # NOTE(review): "test_features" looks like a typo for "text_features"; it is kept
    # unchanged because it is the compiled graph's input name.
    input_info = [("test_features", text_feature_shape, "float32")]

    compile_cfg = RBLNCompileConfig(input_info=input_info)
    rbln_config.set_compile_cfgs([compile_cfg])
    return rbln_config
|
|
242
|
+
|
|
243
|
+
def generate_encoder_output_proposals(self, *args, **kwargs):
    """Delegate to ``GroundingDinoModel.generate_encoder_output_proposals`` with this instance as ``self``.

    Presumably relies on the host-side modules rebuilt in ``_setup_cpu_instances``
    (e.g. ``enc_output``); verify against the transformers implementation.
    """
    hf_impl = GroundingDinoModel.generate_encoder_output_proposals
    return hf_impl(self, *args, **kwargs)
|
|
245
|
+
|
|
246
|
+
def get_valid_ratio(self, *args, **kwargs):
    """Delegate to ``GroundingDinoModel.get_valid_ratio`` with this instance as ``self``."""
    hf_impl = GroundingDinoModel.get_valid_ratio
    return hf_impl(self, *args, **kwargs)
|
|
248
|
+
|
|
249
|
+
def _model_forward(
    self,
    pixel_values: Tensor,
    input_ids: Tensor,
    token_type_ids: Optional[Tensor] = None,
    attention_mask: Optional[Tensor] = None,
    pixel_mask: Optional[Tensor] = None,
    encoder_outputs=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    _init_reference_points=None,
):
    """Run the Grounding DINO base model: text backbone, vision backbone, encoder, decoder.

    Follows ``GroundingDinoModel.forward`` but dispatches to the submodules attached
    to ``self`` (``text_backbone``, ``text_projection``, ``backbone``, ``encoder``,
    ``decoder`` and the projection / embedding layers).

    Args:
        pixel_values: Image batch, unpacked as ``(batch_size, num_channels, height, width)``.
        input_ids: Token ids. Truncated to ``config.max_text_len`` if longer.
        token_type_ids: Segment ids. Zeros when omitted.
        attention_mask: Text padding mask. Ones when omitted.
        pixel_mask: Image padding mask. All ones when omitted.
        encoder_outputs: Precomputed encoder outputs (``GroundingDinoEncoderOutput``
            or tuple). When given, the encoder is not run.
        output_attentions, output_hidden_states, return_dict: Default to ``self.config``.
        _init_reference_points: Optional override of the two-stage initial reference points.

    Returns:
        A ``GroundingDinoModelOutput`` if ``return_dict`` is truthy. Otherwise the tuple
        ``(last_hidden_state, init_reference_points, *decoder_outputs[1:],
        *encoder_outputs, *enc_outputs)``, where ``enc_outputs`` holds the non-None
        two-stage proposal tensors.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Both tensors are cast/sliced unconditionally below, so fill in defaults.
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)

    text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids)

    text_token_mask = attention_mask.bool()  # just to avoid renaming everywhere

    max_text_len = self.config.max_text_len
    if text_self_attention_masks.shape[1] > max_text_len:
        text_self_attention_masks = text_self_attention_masks[:, :max_text_len, :max_text_len]
        position_ids = position_ids[:, :max_text_len]
        input_ids = input_ids[:, :max_text_len]
        token_type_ids = token_type_ids[:, :max_text_len]
        text_token_mask = text_token_mask[:, :max_text_len]

    # Extract text features from text backbone
    text_outputs = self.text_backbone(
        input_ids, text_self_attention_masks.to(torch.long), token_type_ids, position_ids, return_dict=return_dict
    )
    text_features = text_outputs.last_hidden_state if return_dict else text_outputs[0]
    text_features = self.text_projection(text_features)

    batch_size, num_channels, height, width = pixel_values.shape
    device = pixel_values.device

    if pixel_mask is None:
        pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)

    # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
    # First, send pixel_values + pixel_mask through the backbone to obtain the features,
    # which is a list of tuples.
    features = self.backbone(pixel_values)[0]
    vision_features = []
    for feature_map in features:
        # downsample pixel_mask to match shape of corresponding feature_map
        mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
        vision_features.append((feature_map, mask))

    position_embeddings_list = []
    for feature_map, mask in vision_features:
        # position encoding
        position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))

    # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
    feature_maps = []
    masks = []
    for level, (source, mask) in enumerate(vision_features):
        feature_maps.append(self.input_proj_vision[level](source))
        masks.append(mask)

    # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
    if self.config.num_feature_levels > len(feature_maps):
        _len_sources = len(feature_maps)
        for level in range(_len_sources, self.config.num_feature_levels):
            if level == _len_sources:
                source = self.input_proj_vision[level](vision_features[-1][0])
            else:
                source = self.input_proj_vision[level](feature_maps[-1])
            mask = nn.functional.interpolate(pixel_mask[None].float(), size=source.shape[-2:]).to(torch.bool)[0]
            pos_l = self.backbone_position_embedding(source, mask).to(source.dtype)
            feature_maps.append(source)
            masks.append(mask)
            position_embeddings_list.append(pos_l)

    # Create queries
    query_embeds = None
    if self.config.embedding_init_target or self.config.two_stage:
        query_embeds = self.query_position_embeddings.weight

    # Prepare encoder inputs (by flattening)
    source_flatten = []
    mask_flatten = []
    lvl_pos_embed_flatten = []
    spatial_shapes_list = []
    for level, (source, mask, pos_embed) in enumerate(zip(feature_maps, masks, position_embeddings_list)):
        batch_size, num_channels, height, width = source.shape
        spatial_shape = (height, width)
        spatial_shapes_list.append(spatial_shape)
        source = source.flatten(2).transpose(1, 2)
        mask = mask.flatten(1)
        pos_embed = pos_embed.flatten(2).transpose(1, 2)
        lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
        lvl_pos_embed_flatten.append(lvl_pos_embed)
        source_flatten.append(source)
        mask_flatten.append(mask)
    source_flatten = torch.cat(source_flatten, 1)
    mask_flatten = torch.cat(mask_flatten, 1)
    lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
    spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
    level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
    valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
    valid_ratios = valid_ratios.float()

    # Fourth, send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output)
    # through the encoder. Also provide spatial_shapes, level_start_index and valid_ratios.
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            vision_features=source_flatten,
            vision_attention_mask=~mask_flatten,
            vision_position_embedding=lvl_pos_embed_flatten,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            text_features=text_features,
            text_attention_mask=~text_token_mask,
            text_position_embedding=None,
            text_self_attention_masks=~text_self_attention_masks,
            text_position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
    elif return_dict and not isinstance(encoder_outputs, GroundingDinoEncoderOutput):
        # A caller-supplied tuple must be wrapped so the attribute access in the
        # return_dict branch below works.
        encoder_outputs = GroundingDinoEncoderOutput(
            last_hidden_state_vision=encoder_outputs[0],
            last_hidden_state_text=encoder_outputs[1],
            vision_hidden_states=encoder_outputs[2] if output_hidden_states else None,
            text_hidden_states=encoder_outputs[3] if output_hidden_states else None,
            attentions=encoder_outputs[-1] if output_attentions else None,
        )

    # Fifth, prepare decoder inputs
    topk_proposals = None
    enc_outputs_class = None
    enc_outputs_coord_logits = None
    encoder_logits = None
    encoder_pred_boxes = None
    if self.config.two_stage:
        object_query_embedding, output_proposals = self.generate_encoder_output_proposals(
            encoder_outputs[0], ~mask_flatten, spatial_shapes
        )

        # hack implementation as in two-stage Deformable DETR
        # apply a detection head to each pixel (A.4 in paper)
        # linear projection for bounding box binary classification (i.e. foreground and background)
        enc_outputs_class = self.encoder_output_class_embed(
            object_query_embedding, encoder_outputs[1], text_token_mask
        )
        # 3-layer FFN to predict bounding boxes coordinates (bbox regression branch)
        delta_bbox = self.encoder_output_bbox_embed(object_query_embedding)
        enc_outputs_coord_logits = delta_bbox + output_proposals

        # only keep top scoring `config.num_queries` proposals
        topk = self.config.num_queries
        topk_logits = enc_outputs_class.max(-1)[0]
        topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]
        topk_coords_logits = torch.gather(
            enc_outputs_coord_logits, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
        )

        topk_coords_logits = topk_coords_logits.detach()
        reference_points = (
            topk_coords_logits.sigmoid() if _init_reference_points is None else _init_reference_points
        )
        init_reference_points = reference_points
        if query_embeds is not None:
            target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
        else:
            target = torch.gather(
                object_query_embedding, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
            ).detach()

        # Set intermediate topk proposals (coords and class) for loss computation
        encoder_pred_boxes = reference_points
        encoder_logits = self.encoder_output_class_embed(target, text_features, text_token_mask)
    else:
        target = query_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
        reference_points = self.reference_points.weight.unsqueeze(0).repeat(batch_size, 1, 1).sigmoid()
        init_reference_points = reference_points

    decoder_outputs = self.decoder(
        inputs_embeds=target,
        vision_encoder_hidden_states=encoder_outputs[0],
        vision_encoder_attention_mask=mask_flatten,
        text_encoder_hidden_states=encoder_outputs[1],
        text_encoder_attention_mask=~text_token_mask,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        spatial_shapes_list=spatial_shapes_list,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        self_attn_mask=None,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if not return_dict:
        enc_outputs = tuple(
            value
            for value in [
                enc_outputs_class,
                enc_outputs_coord_logits,
                encoder_logits,
                encoder_pred_boxes,
            ]
            if value is not None
        )
        # The encoder is invoked with return_dict=True, so encoder_outputs is usually a
        # ModelOutput. `tuple + ModelOutput` raises TypeError, so convert it explicitly.
        encoder_outputs_tuple = (
            encoder_outputs if isinstance(encoder_outputs, tuple) else encoder_outputs.to_tuple()
        )
        tuple_outputs = (
            (decoder_outputs[0], init_reference_points) + decoder_outputs[1:] + encoder_outputs_tuple + enc_outputs
        )

        return tuple_outputs

    return GroundingDinoModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        init_reference_points=init_reference_points,
        intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
        intermediate_reference_points=decoder_outputs.intermediate_reference_points,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        encoder_last_hidden_state_vision=encoder_outputs.last_hidden_state_vision,
        encoder_last_hidden_state_text=encoder_outputs.last_hidden_state_text,
        encoder_vision_hidden_states=encoder_outputs.vision_hidden_states,
        encoder_text_hidden_states=encoder_outputs.text_hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        enc_outputs_class=enc_outputs_class,
        enc_outputs_coord_logits=enc_outputs_coord_logits,
        encoder_logits=encoder_logits,
        encoder_pred_boxes=encoder_pred_boxes,
    )
|
|
480
|
+
|
|
481
|
+
def pad_image_to_rbln_config(self, pixel_values: torch.FloatTensor, pixel_mask: torch.BoolTensor):
    """Zero-pad images (and their mask) on the bottom/right to the compiled encoder size.

    The target size is ``self.rbln_config.encoder.image_height`` x ``image_width``.
    When ``pixel_mask`` is None, an all-ones long mask of the input size is created
    first, so padded pixels end up masked with 0.

    Returns:
        ``(pixel_values, pixel_mask)``, padded if needed.

    Raises:
        ValueError: If the input is larger than the compiled size in either dimension.
    """
    batch_size, _, height, width = pixel_values.shape
    encoder_cfg = self.rbln_config.encoder
    target_h, target_w = encoder_cfg.image_height, encoder_cfg.image_width

    if pixel_mask is None:
        pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=pixel_values.device)

    extra_h = target_h - height
    extra_w = target_w - width
    if min(extra_h, extra_w) < 0:
        raise ValueError(
            f"Image size {height}x{width} is larger than encoder's image_size {target_h}x{target_w}"
        )

    # Both deltas are non-negative here, so truthiness means "some padding needed".
    if extra_h or extra_w:
        padding = (0, extra_w, 0, extra_h)  # (left, right, top, bottom) on the last two dims
        pixel_values = torch.nn.functional.pad(pixel_values, padding, value=0)
        pixel_mask = torch.nn.functional.pad(pixel_mask, padding, value=0)

    return pixel_values, pixel_mask
|
|
503
|
+
|
|
504
|
+
def pad_text_to_rbln_config(
    self,
    input_ids: torch.LongTensor,
    token_type_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
):
    """Fit text inputs to the compiled text length ``self.config.max_text_len``.

    Missing ``token_type_ids`` default to zeros and a missing ``attention_mask`` to
    ones. Shorter sequences are right-padded with 0, so padded tokens are masked out.
    Longer sequences are truncated to ``max_text_len``.

    Returns:
        ``(input_ids, token_type_ids, attention_mask)``, each of length ``max_text_len``.
    """
    _, seq_len = input_ids.shape
    max_text_len = self.config.max_text_len
    token_type_ids = token_type_ids if token_type_ids is not None else torch.zeros_like(input_ids)
    attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)

    if seq_len < max_text_len:
        padding = (0, max_text_len - seq_len, 0, 0)
        input_ids = torch.nn.functional.pad(input_ids, padding, value=0)
        token_type_ids = torch.nn.functional.pad(token_type_ids, padding, value=0)
        attention_mask = torch.nn.functional.pad(attention_mask, padding, value=0)
    elif seq_len > max_text_len:
        # `_model_forward` keeps only max_text_len tokens, so the text hidden states are
        # max_text_len long. Truncate here too, so the attention_mask that `forward`
        # passes to `class_embed` matches them instead of failing to broadcast.
        input_ids = input_ids[:, :max_text_len]
        token_type_ids = token_type_ids[:, :max_text_len]
        attention_mask = attention_mask[:, :max_text_len]

    return input_ids, token_type_ids, attention_mask
|
|
520
|
+
|
|
521
|
+
def forward(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.LongTensor,
    token_type_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    pixel_mask: Optional[torch.BoolTensor] = None,
    encoder_outputs: Optional[Union[GroundingDinoEncoderOutput, Tuple]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[GroundingDinoObjectDetectionOutput, Tuple]:
    """
    Forward pass for the RBLN-optimized GroundingDinoForObjectDetection model.

    Inputs are first padded to the compiled shapes: the image to
    ``rbln_config.encoder.image_height`` x ``image_width``, and the text via
    ``pad_text_to_rbln_config`` to ``config.max_text_len``. The base model then runs
    under ``torch.inference_mode()``. Finally, the per-stage ``class_embed`` /
    ``bbox_embed`` heads are applied; the last stage gives ``logits`` / ``pred_boxes``.

    Args:
        pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
        input_ids (torch.LongTensor of shape (batch_size, text_sequence_length)): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
        token_type_ids (torch.LongTensor of shape (batch_size, text_sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
        attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
        pixel_mask (torch.Tensor of shape (batch_size, height, width), optional): Mask to avoid performing attention on padding pixel values.
        encoder_outputs (Tuple consists of last_hidden_state of shape(batch_size, sequence_length, hidden_size), optional): A sequence of hidden-states at the output of the last layer of the encoder.
        output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
        return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
        **kwargs: Forwarded verbatim to ``_model_forward``.
            NOTE(review): ``_model_forward`` accepts only ``_init_reference_points`` beyond
            the named arguments, so any other keyword (e.g. ``labels``) raises TypeError.

    Returns:
        The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a GroundingDinoObjectDetectionOutput object.
    """

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Pad image to rbln_config.image_height and rbln_config.image_width
    pixel_values, pixel_mask = self.pad_image_to_rbln_config(pixel_values, pixel_mask)
    # Pad text ids / token types / attention mask to config.max_text_len
    input_ids, token_type_ids, attention_mask = self.pad_text_to_rbln_config(
        input_ids, token_type_ids, attention_mask
    )

    with torch.inference_mode():
        # First, send images through the Grounding DINO base model to obtain encoder + decoder outputs
        outputs = self._model_forward(
            pixel_values=pixel_values,
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            pixel_mask=pixel_mask,
            encoder_outputs=encoder_outputs,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

    # Tuple-mode position of encoder_last_hidden_state_text. The tuple starts with
    # (last_hidden_state, init_reference_points, intermediate_hidden_states,
    # intermediate_reference_points), then optional decoder hidden states / attentions,
    # then the encoder's vision output.
    # NOTE(review): assumes the decoder adds exactly one entry each for hidden states and
    # attentions. Also, output_attentions/output_hidden_states are NOT defaulted from
    # self.config here (only inside _model_forward), so passing None while the config
    # enables them would shift this index. Confirm.
    idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
    enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
    hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
    init_reference_points = outputs.init_reference_points if return_dict else outputs[1]
    inter_references_points = outputs.intermediate_reference_points if return_dict else outputs[3]

    # class logits + predicted bounding boxes
    outputs_classes = []
    outputs_coords = []

    # Dim 1 of hidden_states indexes decoder stages (see hidden_states[:, level]).
    # Presumably the shape is (batch_size, num_stages, num_queries, hidden_size); TODO confirm.
    # predict class and bounding box deltas for each stage
    num_levels = hidden_states.shape[1]
    for level in range(num_levels):
        # Stage 0 refines the initial reference points; stage k refines stage k-1's output.
        if level == 0:
            reference = init_reference_points
        else:
            reference = inter_references_points[:, level - 1]
        # Reference points are in sigmoid space; move them to logit space so the predicted
        # delta can be added before the final .sigmoid() below.
        reference = torch.special.logit(reference, eps=1e-5)
        outputs_class = self.class_embed[level](
            vision_hidden_state=hidden_states[:, level],
            text_hidden_state=enc_text_hidden_state,
            text_token_mask=attention_mask.bool(),
        )
        delta_bbox = self.bbox_embed[level](hidden_states[:, level])

        # 4-d references are full boxes; 2-d references only offset the box center.
        reference_coordinates = reference.shape[-1]
        if reference_coordinates == 4:
            outputs_coord_logits = delta_bbox + reference
        elif reference_coordinates == 2:
            delta_bbox[..., :2] += reference
            outputs_coord_logits = delta_bbox
        else:
            raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")
        outputs_coord = outputs_coord_logits.sigmoid()
        outputs_classes.append(outputs_class)
        outputs_coords.append(outputs_coord)
    outputs_class = torch.stack(outputs_classes)
    outputs_coord = torch.stack(outputs_coords)

    # Final predictions come from the last decoder stage.
    logits = outputs_class[-1]
    pred_boxes = outputs_coord[-1]

    if not return_dict:
        # No auxiliary (per-stage) outputs are emitted; kept as an empty placeholder.
        auxiliary_outputs = []
        output = [logits, pred_boxes, *auxiliary_outputs, *outputs, input_ids]
        output = tuple(out for out in output if out is not None)
        return output

    return GroundingDinoObjectDetectionOutput(
        logits=logits,
        pred_boxes=pred_boxes,
        last_hidden_state=outputs.last_hidden_state,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        encoder_last_hidden_state_vision=outputs.encoder_last_hidden_state_vision,
        encoder_last_hidden_state_text=outputs.encoder_last_hidden_state_text,
        encoder_vision_hidden_states=outputs.encoder_vision_hidden_states,
        encoder_text_hidden_states=outputs.encoder_text_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
        intermediate_hidden_states=outputs.intermediate_hidden_states,
        intermediate_reference_points=outputs.intermediate_reference_points,
        init_reference_points=outputs.init_reference_points,
        enc_outputs_class=outputs.enc_outputs_class,
        enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
        encoder_logits=outputs.encoder_logits,
        encoder_pred_boxes=outputs.encoder_pred_boxes,
        input_ids=input_ids,
    )
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def _update_spatial_shapes(model_config, rbln_config):
    """Compute the per-level ``[height, width]`` feature-map sizes for the compiled image size.

    The image is split into ``ceil(size / patch_size)`` patches per axis. Backbone
    stage ``k`` (from ``backbone_config.out_indices``) is that patch grid halved,
    rounding up, ``k - 1`` times. If ``model_config.num_feature_levels`` exceeds the
    number of backbone outputs, each extra level halves the previous one as
    ``(s - 1) // 2 + 1``. This mirrors the chained extra projections in ``_model_forward``.

    Args:
        model_config: HF model config providing ``backbone_config.patch_size``,
            ``backbone_config.out_indices`` and ``num_feature_levels``.
        rbln_config: Config providing ``image_height`` / ``image_width``. Its
            ``spatial_shapes_list`` is set in place.

    Returns:
        ``rbln_config``, with ``spatial_shapes_list`` updated.
    """

    def down_sampled_size(x, depth: int = 1):
        # Halve ``x`` (rounding up) ``depth`` times. A depth <= 0 leaves ``x`` unchanged.
        for _ in range(depth):
            x = (x + 1) // 2
        return x

    def num_patches(image_size, patch_size):
        # Ceil-division: number of patches along one image axis.
        return (image_size + patch_size - 1) // patch_size

    backbone_config = model_config.backbone_config
    num_patched_h = num_patches(rbln_config.image_height, backbone_config.patch_size)
    # Width must come from image_width (previously image_height was used, which broke non-square images).
    num_patched_w = num_patches(rbln_config.image_width, backbone_config.patch_size)

    spatial_shapes = [
        [down_sampled_size(num_patched_h, out_layer - 1), down_sampled_size(num_patched_w, out_layer - 1)]
        for out_layer in backbone_config.out_indices
    ]

    # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions stacked on the
    # final stage, one per missing level (as in the extra-level loop of `_model_forward`).
    while len(spatial_shapes) < model_config.num_feature_levels:
        last_h, last_w = spatial_shapes[-1][0], spatial_shapes[-1][1]
        spatial_shapes.append([(last_h - 1) // 2 + 1, (last_w - 1) // 2 + 1])

    rbln_config.spatial_shapes_list = spatial_shapes

    return rbln_config
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
class RBLNGroundingDinoEncoder(RBLNModel):
    """RBLN-compiled Grounding DINO encoder.

    Compiles ``_GroundingDinoEncoder`` with static shapes derived from the configured
    image size and ``max_text_len``. ``forward`` adapts the transformers-style encoder
    call to the compiled runtime and repacks its outputs as a
    ``GroundingDinoEncoderOutput``.
    """

    def __post_init__(self, **kwargs):
        # Runtime wrapper around the first compiled graph.
        self.encoder_runtime = RBLNPytorchRuntime(self.model[0])

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
    ) -> torch.nn.Module:
        """Wrap ``model`` in ``_GroundingDinoEncoder`` (eval mode) for compilation."""
        model = _GroundingDinoEncoder(model, rbln_config).eval()
        return model

    @classmethod
    def _update_submodule_config(
        cls,
        model: "PreTrainedModel",
        rbln_config: RBLNModelConfig,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
    ):
        """Infer ``rbln_config.image_size`` from an image processor if unset, then fill spatial shapes.

        ``image_size`` is taken from the first preprocessor with an ``image_processor``
        whose ``size`` has ``height``/``width`` (a tuple), else ``longest_edge``,
        else ``shortest_edge``.
        """
        # `preprocessors` is Optional: treat None as "nothing to inspect".
        for processor in preprocessors or ():
            if rbln_config.image_size is None and hasattr(processor, "image_processor"):
                if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
                    rbln_config.image_size = (
                        processor.image_processor.size["height"],
                        processor.image_processor.size["width"],
                    )
                elif (
                    "longest_edge" in processor.image_processor.size
                    and "shortest_edge" in processor.image_processor.size
                ):
                    rbln_config.image_size = processor.image_processor.size["longest_edge"]
                elif "shortest_edge" in processor.image_processor.size:
                    rbln_config.image_size = processor.image_processor.size["shortest_edge"]
                break
        rbln_config = _update_spatial_shapes(model.config, rbln_config)
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: Optional["PreTrainedModel"] = None,
        model_config: RBLNGroundingDinoEncoderConfig = None,
        rbln_config: Optional[RBLNGroundingDinoEncoderConfig] = None,
    ) -> RBLNGroundingDinoEncoderConfig:
        """Register the static-shape compile config for the encoder graph.

        Raises:
            ValueError: If ``rbln_config.image_size`` is not set.
        """
        if rbln_config.image_size is None:
            raise ValueError("RBLN config must have image_size set for RBLN optimized GroundingDinoEncoder.")

        # Total number of flattened vision tokens across all feature levels.
        vision_seq_len = int((rbln_config.spatial_shapes[:, 0] * rbln_config.spatial_shapes[:, 1]).sum())

        input_info = [
            (
                "vision_features",
                [rbln_config.batch_size, vision_seq_len, model_config.d_model],
                "float32",
            ),
            (
                # forward() broadcasts the per-token mask over d_model before calling the runtime.
                "vision_attention_mask",
                [
                    rbln_config.batch_size,
                    vision_seq_len,
                    model_config.d_model,
                ],
                "float32",
            ),
            (
                "vision_position_embedding",
                [rbln_config.batch_size, vision_seq_len, model_config.d_model],
                "float32",
            ),
            (
                "text_features",
                [rbln_config.batch_size, model_config.max_text_len, model_config.d_model],
                "float32",
            ),
            (
                "text_attention_mask",
                [
                    rbln_config.batch_size,
                    model_config.max_text_len,
                ],
                "float32",
            ),
            (
                "text_self_attention_masks",
                [
                    rbln_config.batch_size,
                    model_config.max_text_len,
                    model_config.max_text_len,
                ],
                "float32",
            ),
            (
                "reference_points",
                [rbln_config.batch_size, vision_seq_len, 4, 2],
                "float32",
            ),
        ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])

        return rbln_config

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalized (x, y) reference points for every vision token at every level.

        For each level, a grid of cell centers is divided by the valid extent of that
        level. All levels are concatenated along dim 1 and then scaled by
        ``valid_ratios``.
        """
        reference_points_list = []
        for level, (height, width) in enumerate(spatial_shapes):
            ref_y, ref_x = meshgrid(
                torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
                torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
                indexing="ij",
            )
            # TODO: valid_ratios could be useless here. check https://github.com/fundamentalvision/Deformable-DETR/issues/36
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, level, 1] * height)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, level, 0] * width)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def validate_output_config(self, output_attentions, output_hidden_states):
        """Ensure the requested output flags match the ones the graph was compiled with.

        Raises:
            ValueError: On any mismatch, since the compiled graph's outputs are fixed.
        """
        if output_attentions != self.rbln_config.output_attentions:
            raise ValueError(
                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
                "Please compile again with the correct argument."
            )

        if output_hidden_states != self.rbln_config.output_hidden_states:
            raise ValueError(
                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                "Please compile again with the correct argument."
            )

    def forward(
        self,
        vision_features: Tensor,
        vision_attention_mask: Tensor,
        vision_position_embedding: Tensor,
        spatial_shapes: Tensor,
        spatial_shapes_list: List[Tuple[int, int]],
        level_start_index: Tensor,
        valid_ratios: Optional[Tensor] = None,
        text_features: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        text_position_embedding: Optional[Tensor] = None,
        text_self_attention_masks: Optional[Tensor] = None,
        text_position_ids: Optional[Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the compiled encoder with a transformers-compatible signature.

        Reference points are computed on CPU from ``spatial_shapes`` / ``valid_ratios``.
        Masks are cast to float32, and the vision mask is repeated over ``d_model`` to
        match the compiled input. ``spatial_shapes_list``, ``level_start_index``,
        ``text_position_embedding`` and ``text_position_ids`` are accepted for interface
        compatibility but are not used.

        Returns:
            The raw runtime outputs as a tuple when ``return_dict`` is falsy. Otherwise a
            ``GroundingDinoEncoderOutput`` unpacked in the order: vision last hidden
            state, text last hidden state, [encoder_layers + 1 vision hidden states,
            encoder_layers + 1 text hidden states], [attentions].

        Raises:
            ValueError: If the output flags differ from the compiled configuration.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        self.validate_output_config(output_attentions, output_hidden_states)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
        # Compiled input is [batch, vision_seq_len, d_model] float32.
        vision_attention_mask = vision_attention_mask.to(torch.float32).unsqueeze(-1).repeat(1, 1, self.config.d_model)

        enc_outputs = self.encoder_runtime(
            vision_features=vision_features,
            vision_attention_mask=vision_attention_mask,
            vision_position_embedding=vision_position_embedding,
            text_features=text_features,
            text_attention_mask=text_attention_mask.to(torch.float32),
            text_self_attention_masks=text_self_attention_masks.to(torch.float32),
            reference_points=reference_points,
        )

        if not return_dict:
            return tuple(enc_outputs)

        # Consume the flat runtime outputs front-to-back in compiled order.
        enc_outputs = list(enc_outputs)
        last_hidden_state_vision = enc_outputs.pop(0)
        last_hidden_state_text = enc_outputs.pop(0)
        vision_hidden_states = (
            tuple([enc_outputs.pop(0) for _ in range(self.config.encoder_layers + 1)])
            if self.rbln_config.output_hidden_states
            else None
        )
        text_hidden_states = (
            tuple([enc_outputs.pop(0) for _ in range(self.config.encoder_layers + 1)])
            if self.rbln_config.output_hidden_states
            else None
        )
        attentions = tuple(enc_outputs) if self.rbln_config.output_attentions else None

        return GroundingDinoEncoderOutput(
            last_hidden_state_vision=last_hidden_state_vision,
            last_hidden_state_text=last_hidden_state_text,
            vision_hidden_states=vision_hidden_states,
            text_hidden_states=text_hidden_states,
            attentions=attentions,
        )
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
class RBLNGroundingDinoDecoder(RBLNModel):
|
|
876
|
+
def __post_init__(self, **kwargs):
    """Expose the first compiled model through a PyTorch-callable runtime as ``self.decoder_runtime``."""
    compiled_decoder = self.model[0]
    self.decoder_runtime = RBLNPytorchRuntime(compiled_decoder)
|
|
878
|
+
|
|
879
|
+
@classmethod
def _wrap_model_if_needed(
    cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
) -> torch.nn.Module:
    """Wrap ``model`` in ``_GroundingDinoDecoder`` and switch it to eval mode for compilation."""
    wrapped = _GroundingDinoDecoder(model, rbln_config)
    return wrapped.eval()
|
|
884
|
+
|
|
885
|
+
@classmethod
def _update_submodule_config(
    cls,
    model: "PreTrainedModel",
    rbln_config: RBLNModelConfig,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
):
    """Infer ``rbln_config.image_size`` from an image processor if unset, then fill spatial shapes.

    ``image_size`` is taken from the first preprocessor exposing ``image_processor``.
    It uses ``size["height"]``/``size["width"]`` (as a tuple) when present, else
    ``longest_edge`` (when ``shortest_edge`` is also present), else ``shortest_edge``.
    ``_update_spatial_shapes`` is then applied and the config returned.
    """
    # `preprocessors` is Optional: treat None as "nothing to inspect" instead of failing to iterate.
    for processor in preprocessors or ():
        if rbln_config.image_size is None and hasattr(processor, "image_processor"):
            if "height" in processor.image_processor.size and "width" in processor.image_processor.size:
                rbln_config.image_size = (
                    processor.image_processor.size["height"],
                    processor.image_processor.size["width"],
                )
            elif (
                "longest_edge" in processor.image_processor.size
                and "shortest_edge" in processor.image_processor.size
            ):
                rbln_config.image_size = processor.image_processor.size["longest_edge"]
            elif "shortest_edge" in processor.image_processor.size:
                rbln_config.image_size = processor.image_processor.size["shortest_edge"]
            break
    rbln_config = _update_spatial_shapes(model.config, rbln_config)

    return rbln_config
|
|
910
|
+
|
|
911
|
+
@classmethod
|
|
912
|
+
def _update_rbln_config(
|
|
913
|
+
cls,
|
|
914
|
+
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
|
915
|
+
model: Optional["PreTrainedModel"] = None,
|
|
916
|
+
model_config: RBLNGroundingDinoDecoderConfig = None,
|
|
917
|
+
rbln_config: Optional[RBLNGroundingDinoEncoderConfig] = None,
|
|
918
|
+
) -> RBLNGroundingDinoEncoderConfig:
|
|
919
|
+
if rbln_config.image_size is None:
|
|
920
|
+
raise ValueError("RBLN config must have image_size set for RBLN optimized GroundingDinoDecoder.")
|
|
921
|
+
|
|
922
|
+
vision_seq_len = int((rbln_config.spatial_shapes[:, 0] * rbln_config.spatial_shapes[:, 1]).sum())
|
|
923
|
+
|
|
924
|
+
input_info = [
|
|
925
|
+
(
|
|
926
|
+
"inputs_embeds",
|
|
927
|
+
[rbln_config.batch_size, model_config.num_queries, model_config.d_model],
|
|
928
|
+
"float32",
|
|
929
|
+
),
|
|
930
|
+
(
|
|
931
|
+
"vision_encoder_hidden_states",
|
|
932
|
+
[
|
|
933
|
+
rbln_config.batch_size,
|
|
934
|
+
vision_seq_len,
|
|
935
|
+
model_config.d_model,
|
|
936
|
+
],
|
|
937
|
+
"float32",
|
|
938
|
+
),
|
|
939
|
+
(
|
|
940
|
+
"vision_encoder_attention_mask",
|
|
941
|
+
[rbln_config.batch_size, vision_seq_len, model_config.d_model],
|
|
942
|
+
"float32",
|
|
943
|
+
),
|
|
944
|
+
(
|
|
945
|
+
"text_encoder_hidden_states",
|
|
946
|
+
[rbln_config.batch_size, model_config.max_text_len, model_config.d_model],
|
|
947
|
+
"float32",
|
|
948
|
+
),
|
|
949
|
+
(
|
|
950
|
+
"text_encoder_attention_mask",
|
|
951
|
+
[
|
|
952
|
+
rbln_config.batch_size,
|
|
953
|
+
model_config.max_text_len,
|
|
954
|
+
],
|
|
955
|
+
"float32",
|
|
956
|
+
),
|
|
957
|
+
(
|
|
958
|
+
"reference_points",
|
|
959
|
+
[
|
|
960
|
+
rbln_config.batch_size,
|
|
961
|
+
model_config.num_queries,
|
|
962
|
+
4,
|
|
963
|
+
],
|
|
964
|
+
"float32",
|
|
965
|
+
),
|
|
966
|
+
(
|
|
967
|
+
"valid_ratios",
|
|
968
|
+
[
|
|
969
|
+
rbln_config.batch_size,
|
|
970
|
+
4,
|
|
971
|
+
2,
|
|
972
|
+
],
|
|
973
|
+
"float32",
|
|
974
|
+
),
|
|
975
|
+
]
|
|
976
|
+
|
|
977
|
+
rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
|
|
978
|
+
return rbln_config
|
|
979
|
+
|
|
980
|
+
def validate_output_config(self, output_attentions, output_hidden_states):
|
|
981
|
+
if output_attentions != self.rbln_config.output_attentions:
|
|
982
|
+
raise ValueError(
|
|
983
|
+
f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
|
|
984
|
+
f"Please compile again with the correct argument."
|
|
985
|
+
)
|
|
986
|
+
if output_hidden_states != self.rbln_config.output_hidden_states:
|
|
987
|
+
raise ValueError(
|
|
988
|
+
f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
|
|
989
|
+
f"Please compile again with the correct argument."
|
|
990
|
+
)
|
|
991
|
+
|
|
992
|
+
def forward(
|
|
993
|
+
self,
|
|
994
|
+
inputs_embeds: torch.Tensor,
|
|
995
|
+
vision_encoder_hidden_states: torch.Tensor,
|
|
996
|
+
vision_encoder_attention_mask: torch.Tensor,
|
|
997
|
+
text_encoder_hidden_states: torch.Tensor,
|
|
998
|
+
text_encoder_attention_mask: torch.Tensor,
|
|
999
|
+
reference_points: torch.Tensor,
|
|
1000
|
+
valid_ratios: torch.Tensor,
|
|
1001
|
+
output_attentions: bool = False,
|
|
1002
|
+
output_hidden_states: bool = False,
|
|
1003
|
+
return_dict: bool = False,
|
|
1004
|
+
**kwargs,
|
|
1005
|
+
):
|
|
1006
|
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
1007
|
+
output_hidden_states = (
|
|
1008
|
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
1009
|
+
)
|
|
1010
|
+
self.validate_output_config(output_attentions, output_hidden_states)
|
|
1011
|
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
1012
|
+
|
|
1013
|
+
reshaped_vision_encoder_attention_mask = (
|
|
1014
|
+
vision_encoder_attention_mask[:, :, None].repeat(1, 1, self.config.d_model).to(torch.float32)
|
|
1015
|
+
)
|
|
1016
|
+
|
|
1017
|
+
# Forward pass through the decoder
|
|
1018
|
+
outputs = self.decoder_runtime(
|
|
1019
|
+
inputs_embeds=inputs_embeds,
|
|
1020
|
+
vision_encoder_hidden_states=vision_encoder_hidden_states,
|
|
1021
|
+
vision_encoder_attention_mask=reshaped_vision_encoder_attention_mask,
|
|
1022
|
+
text_encoder_hidden_states=text_encoder_hidden_states,
|
|
1023
|
+
text_encoder_attention_mask=text_encoder_attention_mask.to(torch.float32),
|
|
1024
|
+
reference_points=reference_points,
|
|
1025
|
+
valid_ratios=valid_ratios,
|
|
1026
|
+
)
|
|
1027
|
+
|
|
1028
|
+
if not return_dict:
|
|
1029
|
+
return outputs
|
|
1030
|
+
|
|
1031
|
+
outputs = list(outputs)
|
|
1032
|
+
last_hidden_state = outputs.pop(0)
|
|
1033
|
+
intermediate_hidden_states = outputs.pop(0)
|
|
1034
|
+
intermediate_reference_points = outputs.pop(0)
|
|
1035
|
+
hidden_states = (
|
|
1036
|
+
tuple([outputs.pop(0) for _ in range(self.config.decoder_layers + 1)])
|
|
1037
|
+
if self.rbln_config.output_hidden_states
|
|
1038
|
+
else None
|
|
1039
|
+
)
|
|
1040
|
+
attentions = tuple(outputs) if self.rbln_config.output_attentions else None
|
|
1041
|
+
|
|
1042
|
+
return GroundingDinoDecoderOutput(
|
|
1043
|
+
last_hidden_state=last_hidden_state,
|
|
1044
|
+
intermediate_hidden_states=intermediate_hidden_states,
|
|
1045
|
+
intermediate_reference_points=intermediate_reference_points,
|
|
1046
|
+
hidden_states=hidden_states,
|
|
1047
|
+
attentions=attentions,
|
|
1048
|
+
)
|