optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py
@@ -0,0 +1,599 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import wraps
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers.models.grounding_dino.modeling_grounding_dino import (
    GroundingDinoDecoder,
    GroundingDinoEncoder,
)


if TYPE_CHECKING:
    from .configuration_grounding_dino import RBLNGroundingDinoDecoderConfig, RBLNGroundingDinoEncoderConfig


def monkey_patch():
    from transformers.models.grounding_dino.modeling_grounding_dino import (
        GroundingDinoBiMultiHeadAttention,
        GroundingDinoEncoderLayer,
        GroundingDinoMultiscaleDeformableAttention,
        MultiScaleDeformableAttention,
    )

    original_forward = GroundingDinoMultiscaleDeformableAttention.forward
    original_bi_multihead_attention_forward = GroundingDinoBiMultiHeadAttention.forward
    original_encoder_layer_forward = GroundingDinoEncoderLayer.forward
    original_multiscale_deform_attn = MultiScaleDeformableAttention.forward

    # Patch the methods with the custom implementations
    GroundingDinoMultiscaleDeformableAttention.forward = _GroundingDinoMultiscaleDeformableAttention.forward
    GroundingDinoBiMultiHeadAttention.forward = _GroundingDinoBiMultiHeadAttention.forward
    GroundingDinoEncoderLayer.forward = _GroundingDinoEncoderLayer.forward
    MultiScaleDeformableAttention.forward = _MultiScaleDeformableAttention.forward

    return (
        original_forward,
        original_bi_multihead_attention_forward,
        original_encoder_layer_forward,
        original_multiscale_deform_attn,
    )


def restore_monkey_patch(
    original_forward,
    original_bi_multihead_attention_forward,
    original_encoder_layer_forward,
    original_multiscale_deform_attn,
):
    from transformers.models.grounding_dino.modeling_grounding_dino import (
        GroundingDinoBiMultiHeadAttention,
        GroundingDinoEncoderLayer,
        GroundingDinoMultiscaleDeformableAttention,
        MultiScaleDeformableAttention,
    )

    # Restore the original methods
    GroundingDinoMultiscaleDeformableAttention.forward = original_forward
    GroundingDinoBiMultiHeadAttention.forward = original_bi_multihead_attention_forward
    GroundingDinoEncoderLayer.forward = original_encoder_layer_forward
    MultiScaleDeformableAttention.forward = original_multiscale_deform_attn


def monkey_patch_decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Apply monkey patch and capture original methods
        original_functions = monkey_patch()
        try:
            # Call the original function
            result = func(*args, **kwargs)
        finally:
            # Restore original methods
            restore_monkey_patch(*original_functions)
        return result

    return wrapper


def get_sine_pos_embed(
    pos_tensor: torch.Tensor, num_pos_feats: int = 128, temperature: int = 10000, exchange_xy: bool = True
) -> Tensor:
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    scaled_pos = pos_tensor.unsqueeze(-1) * scale / dim_t
    reshaped_pos = scaled_pos.view(*scaled_pos.shape[:-1], -1, 2)
    sin_chunk, cos_chunk = torch.split(reshaped_pos, 1, dim=-1)
    sin_embed = sin_chunk.squeeze(-1).sin()
    cos_embed = cos_chunk.squeeze(-1).cos()

    pos_embed = torch.stack((sin_embed, cos_embed), dim=-1).flatten(-2)

    if exchange_xy and pos_tensor.shape[-1] >= 2:
        swapped_embeds = torch.cat([pos_embed[..., 1:2, :], pos_embed[..., 0:1, :], pos_embed[..., 2:, :]], dim=-2)
        pos_embed = swapped_embeds

    position_embeddings = pos_embed.flatten(start_dim=-2)

    return position_embeddings


class _GroundingDinoEncoder(torch.nn.Module):
    def __init__(self, model: "GroundingDinoEncoder", rbln_config: "RBLNGroundingDinoEncoderConfig"):
        super().__init__()
        self.layers = model.layers
        self.config = model.config
        self.rbln_config = rbln_config
        self.spatial_shapes = self.rbln_config.spatial_shapes
        self.spatial_shapes_list = self.rbln_config.spatial_shapes_list
        self.text_position_embedding = model.layers[0].get_text_position_embeddings(
            torch.zeros(1, model.config.max_text_len, model.config.d_model),
            None,
            torch.arange(model.config.max_text_len, dtype=torch.int32).unsqueeze(0),
        )

    @monkey_patch_decorator
    def forward(
        self,
        vision_features: torch.Tensor,
        vision_attention_mask: torch.Tensor,
        vision_position_embedding: torch.Tensor,
        text_features: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_self_attention_masks: Optional[torch.Tensor] = None,
        reference_points: Optional[torch.Tensor] = None,
    ):
        output_attentions = self.rbln_config.output_attentions
        output_hidden_states = self.rbln_config.output_hidden_states

        encoder_vision_states = () if output_hidden_states else None
        encoder_text_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        all_attn_fused_text = () if output_attentions else None
        all_attn_fused_vision = () if output_attentions else None
        all_attn_enhanced_text = () if output_attentions else None
        all_attn_deformable = () if output_attentions else None
        for i, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_vision_states += (vision_features,)
                encoder_text_states += (text_features,)

            (vision_features, text_features), attentions = encoder_layer(
                vision_features=vision_features,
                vision_position_embedding=vision_position_embedding,
                spatial_shapes=self.spatial_shapes,
                spatial_shapes_list=self.spatial_shapes_list,
                level_start_index=None,
                key_padding_mask=vision_attention_mask,
                reference_points=reference_points,
                text_features=text_features,
                text_attention_mask=text_attention_mask,
                text_position_embedding=self.text_position_embedding,
                text_self_attention_masks=text_self_attention_masks,
            )
            if output_attentions:
                all_attn_fused_vision += (attentions[0],)
                all_attn_fused_text += (attentions[1],)
                all_attn_enhanced_text += (attentions[2],)
                all_attn_deformable += (attentions[3],)

        if output_hidden_states:
            encoder_vision_states += (vision_features,)
            encoder_text_states += (text_features,)

        if output_attentions:
            all_attns = (all_attn_fused_vision, all_attn_fused_text, all_attn_enhanced_text, all_attn_deformable)

        enc_outputs = [vision_features, text_features, encoder_vision_states, encoder_text_states, all_attns]

        return tuple(v for v in enc_outputs if v is not None)


class _GroundingDinoDecoder(torch.nn.Module):
    def __init__(self, model: "GroundingDinoDecoder", rbln_config: "RBLNGroundingDinoDecoderConfig"):
        super().__init__()
        self.layers = model.layers
        self.config = model.config
        self.spatial_shapes = rbln_config.spatial_shapes
        self.spatial_shapes_list = rbln_config.spatial_shapes_list
        self.rbln_config = rbln_config
        self.reference_points_head = model.reference_points_head
        self.bbox_embed = model.bbox_embed
        self.layer_norm = model.layer_norm

    @monkey_patch_decorator
    def forward(
        self,
        inputs_embeds,
        vision_encoder_hidden_states,
        vision_encoder_attention_mask=None,
        text_encoder_hidden_states=None,
        text_encoder_attention_mask=None,
        reference_points=None,
        valid_ratios=None,
    ):
        output_attentions = self.rbln_config.output_attentions
        output_hidden_states = self.rbln_config.output_hidden_states

        if inputs_embeds is not None:
            hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_attns = () if output_attentions else None
        all_cross_attns_vision = () if (output_attentions and vision_encoder_hidden_states is not None) else None
        all_cross_attns_text = () if (output_attentions and text_encoder_hidden_states is not None) else None
        intermediate = ()
        intermediate_reference_points = ()

        if text_encoder_attention_mask is not None:
            text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :]
            text_encoder_attention_mask = text_encoder_attention_mask.repeat(
                1, self.config.decoder_attention_heads, self.config.num_queries, 1
            )
            text_encoder_attention_mask = text_encoder_attention_mask
            text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(torch.float16).min

        for idx, decoder_layer in enumerate(self.layers):
            num_coordinates = reference_points.shape[-1]
            if num_coordinates == 4:
                reference_points_input = (
                    reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                )
            elif num_coordinates == 2:
                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
            else:
                raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
            _query_pos = get_sine_pos_embed(reference_points_input[:, :, 0, :], num_pos_feats=self.config.d_model // 2)
            query_pos = self.reference_points_head(_query_pos)

            # In original implementation they apply layer norm before outputting intermediate hidden states
            # Though that's not true between layers, so the layers use as input the output of the previous layer
            # without layer norm
            if output_hidden_states:
                all_hidden_states += (self.layer_norm(hidden_states),)

            layer_outputs = decoder_layer(
                hidden_states=hidden_states,
                position_embeddings=query_pos,
                reference_points=reference_points_input,
                spatial_shapes=self.spatial_shapes,
                spatial_shapes_list=self.spatial_shapes_list,
                level_start_index=None,
                vision_encoder_hidden_states=vision_encoder_hidden_states,
                vision_encoder_attention_mask=vision_encoder_attention_mask,
                text_encoder_hidden_states=text_encoder_hidden_states,
                text_encoder_attention_mask=text_encoder_attention_mask,
                self_attn_mask=None,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[idx](hidden_states)
                num_coordinates = reference_points.shape[-1]
                if num_coordinates == 4:
                    new_reference_points = tmp + torch.special.logit(reference_points, eps=1e-5)
                    new_reference_points = new_reference_points.sigmoid()
                elif num_coordinates == 2:
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + torch.special.logit(reference_points, eps=1e-5)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    raise ValueError(
                        f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}"
                    )
                reference_points = new_reference_points.detach()

            intermediate += (self.layer_norm(hidden_states),)
            intermediate_reference_points += (reference_points,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if text_encoder_hidden_states is not None:
                    all_cross_attns_text += (layer_outputs[2],)

                if vision_encoder_hidden_states is not None:
                    all_cross_attns_vision += (layer_outputs[3],)

        # Keep batch_size as first dimension
        intermediate = torch.stack(intermediate, dim=1)
        intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if output_attentions:
            all_attns += (all_self_attns, all_cross_attns_text, all_cross_attns_vision)

        return tuple(
            v
            for v in [
                hidden_states,
                intermediate,
                intermediate_reference_points,
                all_hidden_states,
                all_attns,
            ]
            if v is not None
        )


class _GroundingDinoEncoderLayer(torch.nn.Module):
    def forward(
        self,
        vision_features: Tensor,
        vision_position_embedding: Tensor,
        spatial_shapes: Tensor,
        spatial_shapes_list: List[Tuple[int, int]],
        level_start_index: Tensor,
        key_padding_mask: Tensor,
        reference_points: Tensor,
        text_features: Optional[Tensor] = None,
        text_attention_mask: Optional[Tensor] = None,
        text_position_embedding: Optional[Tensor] = None,
        text_self_attention_masks: Optional[Tensor] = None,
        text_position_ids: Optional[Tensor] = None,
    ):
        text_position_embedding = self.get_text_position_embeddings(
            text_features, text_position_embedding, text_position_ids
        )

        (vision_features, vision_fused_attn), (text_features, text_fused_attn) = self.fusion_layer(
            vision_features=vision_features,
            text_features=text_features,
            attention_mask_vision=key_padding_mask,
            attention_mask_text=text_attention_mask,
        )

        (text_features, text_enhanced_attn) = self.text_enhancer_layer(
            hidden_states=text_features,
            attention_masks=(1.0 - text_self_attention_masks),  # RBLN FIX, change from ~ to 1.0 -
            position_embeddings=(text_position_embedding if text_position_embedding is not None else None),
        )

        (vision_features, vision_deformable_attn) = self.deformable_layer(
            hidden_states=vision_features,
            attention_mask=(1.0 - key_padding_mask),  # RBLN FIX, change from ~ to 1.0 -
            position_embeddings=vision_position_embedding,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            spatial_shapes_list=spatial_shapes_list,
            level_start_index=level_start_index,
        )

        return (
            (vision_features, text_features),
            (vision_fused_attn, text_fused_attn, text_enhanced_attn, vision_deformable_attn),
        )


class _GroundingDinoMultiscaleDeformableAttention(torch.nn.Module):
    """
    Multiscale deformable attention as proposed in Deformable DETR.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        position_embeddings: Optional[torch.Tensor] = None,
        reference_points=None,
        spatial_shapes=None,
        spatial_shapes_list=None,
        level_start_index=None,
        output_attentions: bool = False,
    ):
        # add position embeddings to the hidden states before projecting to queries and keys
        if position_embeddings is not None:
            hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

        batch_size, num_queries, _ = hidden_states.shape
        batch_size, sequence_length, _ = encoder_hidden_states.shape
        # Ignore copy
        if torch.compiler.is_exporting():
            torch._check(
                (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().item() == sequence_length,
                "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
            )
        else:
            if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
                raise ValueError(
                    "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
                )

        value = self.value_proj(encoder_hidden_states)
        if attention_mask is not None:
            # RBLN FIX: bool tensor to float tensor
            value = attention_mask * value

        value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )
        attention_weights = self.attention_weights(hidden_states).view(
            batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        attention_weights = F.softmax(attention_weights, -1).view(
            batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
        )
        # batch_size, num_queries, n_heads, n_levels, n_points, 2
        num_coordinates = reference_points.shape[-1]
        if num_coordinates == 2:
            offset_normalizer = 0.5 * torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_grids = (
                2 * reference_points[:, :, None, :, None, :]
                - 1
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif num_coordinates == 4:
            ref_points_xy, ref_points_wh = torch.split(reference_points, 2, dim=-1)
            ref_points_xy = ref_points_xy[:, :, None, :, None, :]
            ref_points_wh = ref_points_wh[:, :, None, :, None, :]
            ref_points_grids = 2 * ref_points_xy - 1
            offset_grids = sampling_offsets / self.n_points * ref_points_wh
            sampling_grids = ref_points_grids + offset_grids

        else:
            raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

        output = self.attn(
            value,
            spatial_shapes,
            spatial_shapes_list,
            level_start_index,
            sampling_grids,
            attention_weights,
            self.im2col_step,
        )

        output = self.output_proj(output)

        return output, attention_weights


class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
    def forward(
        self,
        vision_features: torch.FloatTensor,
        text_features: torch.FloatTensor,
        vision_attention_mask: Optional[torch.BoolTensor] = None,
        text_attention_mask: Optional[torch.BoolTensor] = None,
    ) -> Tuple[Tuple[torch.FloatTensor, torch.FloatTensor], Tuple[torch.FloatTensor, torch.FloatTensor]]:
        batch_size, tgt_len, _ = vision_features.size()

        vision_query_states = self.vision_proj(vision_features) * self.scale
        vision_query_states = self._reshape(vision_query_states, tgt_len, batch_size)

        text_key_states = self.text_proj(text_features)
        text_key_states = self._reshape(text_key_states, -1, batch_size)

        vision_value_states = self.values_vision_proj(vision_features)
        vision_value_states = self._reshape(vision_value_states, -1, batch_size)

        text_value_states = self.values_text_proj(text_features)
        text_value_states = self._reshape(text_value_states, -1, batch_size)

        proj_shape = (batch_size * self.num_heads, -1, self.head_dim)

        vision_query_states = vision_query_states.view(*proj_shape)
        text_key_states = text_key_states.view(*proj_shape)
        vision_value_states = vision_value_states.view(*proj_shape)
        text_value_states = text_value_states.view(*proj_shape)

        src_len = text_key_states.size(1)
        attn_weights = torch.bmm(vision_query_states, text_key_states.transpose(1, 2))  # bs*nhead, nimg, ntxt

        if attn_weights.size() != (batch_size * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        # RBLN FIX: max_values from scalar to vector
        attn_weights = attn_weights - torch.max(attn_weights).reshape(1).repeat(src_len)
        # # Do not increase -50000/50000, data type half has quite limited range
        attn_weights = torch.clamp(attn_weights, min=-50000, max=50000)

        # RBLN FIX: max_values from scalar to vector
        text_attn_weights = attn_weights - torch.max(attn_weights, dim=1, keepdim=True)[0].repeat(1, tgt_len, 1)

        # # Do not increase -50000/50000, data type half has quite limited range
        text_attn_weights = torch.clamp(text_attn_weights, min=-50000, max=50000)

        text_attn_weights = text_attn_weights.transpose(1, 2)

        # mask vision for language
        if vision_attention_mask is not None:
            # RBLN FIX: bool tensor to float tensor
            mask = vision_attention_mask * torch.finfo(torch.float16).min
            text_attn_weights = text_attn_weights.transpose(1, 2) + mask
            text_attn_weights = text_attn_weights.transpose(1, 2)

        text_attn_weights = text_attn_weights.softmax(dim=-1)

        # mask language for vision
        if text_attention_mask is not None:
            text_attention_mask = text_attention_mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
            # RBLN FIX: bool tensor to float tensor
            mask = text_attention_mask * torch.finfo(torch.float16).min
            attn_weights = attn_weights + mask

        vision_attn_weights = attn_weights.softmax(dim=-1)

        vision_attn_probs = F.dropout(vision_attn_weights, p=self.dropout, training=self.training)
        text_attn_probs = F.dropout(text_attn_weights, p=self.dropout, training=self.training)

        vision_attn_output = torch.bmm(vision_attn_probs, text_value_states)
        text_attn_output = torch.bmm(text_attn_probs, vision_value_states)

        if vision_attn_output.size() != (batch_size * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`vision_attn_output` should be of size {(batch_size, self.num_heads, tgt_len, self.head_dim)}, but is {vision_attn_output.size()}"
            )

        if text_attn_output.size() != (batch_size * self.num_heads, src_len, self.head_dim):
            raise ValueError(
                f"`text_attn_output` should be of size {(batch_size, self.num_heads, src_len, self.head_dim)}, but is {text_attn_output.size()}"
            )

        vision_attn_output = vision_attn_output.view(batch_size, self.num_heads, tgt_len, self.head_dim)
        vision_attn_output = vision_attn_output.transpose(1, 2)
        vision_attn_output = vision_attn_output.reshape(batch_size, tgt_len, self.embed_dim)

        text_attn_output = text_attn_output.view(batch_size, self.num_heads, src_len, self.head_dim)
        text_attn_output = text_attn_output.transpose(1, 2)
        text_attn_output = text_attn_output.reshape(batch_size, src_len, self.embed_dim)

        vision_attn_output = self.out_vision_proj(vision_attn_output)
        text_attn_output = self.out_text_proj(text_attn_output)

        return (vision_attn_output, vision_attn_weights), (text_attn_output, text_attn_weights)


class _MultiScaleDeformableAttention(torch.nn.Module):
    def forward(
        self,
        value: Tensor,
        value_spatial_shapes: Tensor,
        value_spatial_shapes_list: List[Tuple],
        level_start_index: Tensor,
        sampling_grids: Tensor,
        attention_weights: Tensor,
        im2col_step: int,
    ):
        batch_size, _, num_heads, hidden_dim = value.shape
        _, num_queries, num_heads, num_levels, num_points, _ = sampling_grids.shape
        value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
        sampling_value_list = []
        sampling_grids_list = [t.squeeze(3) for t in torch.split(sampling_grids, 1, dim=3)]
        for level_id, (height, width) in enumerate(value_spatial_shapes_list):
            value_l_ = (
                value_list[level_id].permute(0, 2, 3, 1).reshape(batch_size * num_heads, hidden_dim, height, width)
            )
            sampling_grid_l_ = sampling_grids_list[level_id].transpose(1, 2).flatten(0, 1)
            sampling_value_l_ = torch.nn.functional.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)

        sampling_values = torch.cat(sampling_value_list, dim=-1)
        attention_weights_prep = attention_weights.transpose(1, 2)
        values_permuted = sampling_values.permute(0, 2, 3, 1)

        weights_for_matmul = attention_weights_prep.reshape(
            batch_size * num_heads, num_queries, 1, num_levels * num_points
        )
        output_before_permute = torch.matmul(weights_for_matmul, values_permuted)
        output_before_view = output_before_permute.squeeze(2).permute(0, 2, 1)
        output = output_before_view.reshape(batch_size, num_heads * hidden_dim, num_queries)

        return output.transpose(1, 2).contiguous()
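
For reference, a minimal sketch of exercising the get_sine_pos_embed helper defined in the file above. It assumes this release of optimum-rbln (with its torch and transformers dependencies) is installed and that the import path matches the file layout listed earlier; the tensor shapes are illustrative only, not taken from the package.

# Illustrative only: sine/cosine positional embeddings for normalized box coordinates.
# Assumes optimum-rbln 0.9.3.post1 and torch are importable; shapes below are made up.
import torch
from optimum.rbln.transformers.models.grounding_dino.grounding_dino_architecture import get_sine_pos_embed

boxes = torch.rand(2, 900, 4)  # (batch, num_queries, 4) normalized (cx, cy, w, h)
pos_embed = get_sine_pos_embed(boxes, num_pos_feats=128)  # 128 features per coordinate
print(pos_embed.shape)  # torch.Size([2, 900, 512])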