optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
|
20
|
+
from transformers.utils import logging
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
logger = logging.get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class WhisperWrapper:
    """Hold the encoder and decoder wrappers built from a single Whisper model.

    Args:
        model: The Whisper model whose encoder and decoder are wrapped.
        use_attention_mask: Passed to ``WhisperDecoderWrapper`` as ``use_attention_mask``.
        rbln_token_timestamps: Passed to ``WhisperDecoderWrapper`` as ``output_attentions``;
            when truthy, the decoder wrapper additionally returns stacked cross-attention weights.
    """

    def __init__(self, model, use_attention_mask, rbln_token_timestamps):
        decoder_options = {
            "use_attention_mask": use_attention_mask,
            # Cross-attention weights are only produced when token timestamps are requested.
            "output_attentions": rbln_token_timestamps,
        }
        self.encoder = WhisperEncoderWrapper(model)
        self.decoder = WhisperDecoderWrapper(model, **decoder_options)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class WhisperEncoderWrapper(torch.nn.Module):
    """Run the Whisper encoder and write every decoder layer's cross-attention K/V into a cache.

    The cross-attention key/value projections only depend on the encoder output, so they are
    computed once here rather than at each decoding step.

    Args:
        model: Whisper model providing ``config``, ``get_encoder()`` and ``get_decoder()``.
    """

    def __init__(self, model):
        super().__init__()
        self.config = model.config
        self.encoder = model.get_encoder()
        # Cross-attention belongs to the decoder, so the decoder head count defines the K/V layout.
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads
        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)

    def _extract_cross_kv_projects(self, decoder_layers: nn.Module) -> Tuple[nn.ModuleList, nn.ModuleList]:
        """Collect each decoder layer's cross-attention ``k_proj`` and ``v_proj``, in layer order."""
        return (
            nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
            nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
        )

    def forward(
        self,
        input_features: torch.Tensor,
        b_idx: torch.Tensor,
        cross_key_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode ``input_features`` and update the cross-attention cache.

        Args:
            input_features: Encoder input; its first dimension is taken as the batch size.
                Required (not optional): ``.shape[0]`` is read unconditionally.
            b_idx: Batch-slot indices; only ``b_idx[0]`` is passed to the cache update.
            cross_key_values: Cache tensor that receives the freshly computed K/V.

        Returns:
            The tensor returned by ``rbln_custom_ops.rbln_cache_update``.
        """
        # 1. Encoder last hidden states.
        encoder_outputs = self.encoder(input_features=input_features)
        last_hidden_states = encoder_outputs[0]

        # 2. Pre-compute the cross-attention past key/values consumed by the decoder phase.
        #    Each projection is reshaped to (batch, num_heads, -1, d_kv); entries are appended
        #    as [k_0, v_0, k_1, v_1, ...] and stacked on a new leading axis, giving
        #    (2 * num_layers, batch, num_heads, -1, d_kv).
        cross_kv = []
        batch_size = input_features.shape[0]
        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
            past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
            past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)

            cross_kv.append(past_k)
            cross_kv.append(past_v)

        cross_kv = torch.stack(cross_kv, dim=0)

        # 3. Write the cross-attention K/V to device DRAM. Axis 1 of the stacked tensor is the
        #    batch axis. NOTE(review): the exact write semantics (offset b_idx[0] along
        #    batch_axis) are defined by the RBLN custom op; confirm against its implementation.
        batch_axis = torch.tensor(1, dtype=torch.int16)
        cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
            cross_key_values, cross_kv, b_idx[0], batch_axis
        )

        return cross_key_values
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class WhisperDecoderWrapper(torch.nn.Module):
    """Single-step Whisper decoder wrapper that takes its inputs as flat positional arguments.

    Args:
        model: Whisper model providing ``config``, ``proj_out`` and ``get_decoder()``.
        use_attention_mask: Whether ``forward`` expects a decoder attention mask as its second
            positional input.
        output_attentions: Whether ``forward`` also returns the stacked cross-attention weights
            (used for token timestamps).
        **kwargs: Forwarded to ``__post_init__``.
    """

    def __init__(self, model, use_attention_mask: bool = True, output_attentions: bool = False, **kwargs):
        super().__init__()
        self.config = model.config
        self.proj_out = model.proj_out
        self.use_attention_mask = use_attention_mask
        self.output_attentions = output_attentions
        self.__post_init__(model, **kwargs)

    def __post_init__(self, model: nn.Module, **kwargs):
        """
        Post-initialization to extract and configure decoder-related attributes.

        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
        by subclasses to modify or add custom attributes as necessary.
        """
        self.num_layers = self.config.decoder_layers
        self.decoder = self.convert_to_rbln_conditional_generation(model)

    def convert_to_rbln_conditional_generation(self, model: nn.Module) -> nn.Module:
        """Rebuild the decoder with RBLN self/cross attention wrappers around the original layers."""
        original_decoder = model.get_decoder()
        new_layers = []
        for layer in original_decoder.layers:
            self_attn = WhisperSelfAttention(layer.self_attn)
            cross_attn = WhisperCrossAttention(layer.encoder_attn)
            new_layers.append(WhisperDecoderLayer(layer, self_attn, cross_attn))

        return WhisperDecoder(original_decoder, new_layers)

    def forward(
        self,
        *args,
    ) -> Tuple[torch.Tensor, ...]:
        """Run one decoding step.

        Positional inputs, in order::

            decoder_input_ids,
            decoder_attention_mask,   # only when ``use_attention_mask`` is set
            cache_position,
            block_tables,
            cross_kv_cache,
            *self_kv_cache

        ``cross_kv_cache`` and ``self_kv_cache`` are indexed as ``num_layers * 2`` entries laid
        out as ``[key_0, value_0, key_1, value_1, ...]``.

        Returns:
            ``(lm_logits,)``, or ``(lm_logits, cross_attention)`` when ``output_attentions`` is
            set, where ``cross_attention`` stacks the per-layer cross-attention weights on dim 0.
        """
        if self.use_attention_mask:
            (
                decoder_input_ids,
                decoder_attention_mask,
                cache_position,
                block_tables,
                cross_kv_cache,
                *self_kv_cache,
            ) = args
        else:
            decoder_attention_mask = None
            (decoder_input_ids, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args

        # Regroup the flat [key_0, value_0, key_1, value_1, ...] caches into per-layer (key, value) pairs.
        layer_starts = range(0, self.num_layers * 2, 2)
        self_past_key_values = tuple((self_kv_cache[i], self_kv_cache[i + 1]) for i in layer_starts)
        cross_past_key_values = tuple((cross_kv_cache[i], cross_kv_cache[i + 1]) for i in layer_starts)

        # Decode
        sequence_output, cross_attentions = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            cache_position=cache_position,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
            block_tables=block_tables,
        )

        lm_logits = self.proj_out(sequence_output)
        outputs = (lm_logits,)

        if self.output_attentions:
            # The decoder's cross attention is used for token timestamps.
            cross_attention = torch.stack(cross_attentions, dim=0)
            outputs += (cross_attention,)

        return outputs
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class WhisperDecoder(nn.Module):
    """Whisper decoder stack assembled from replacement decoder layers.

    Reuses the original decoder's ``embed_tokens``, ``embed_positions`` and ``layer_norm``;
    ``layers`` supplies the decoder layers that are run in order.
    """

    def __init__(self, model, layers, **kwargs):
        super().__init__()
        self._original_mod = model
        self.layers = nn.ModuleList(layers)
        self.embed_tokens = model.embed_tokens
        self.layer_norm = model.layer_norm
        self.embed_positions = model.embed_positions

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_values: Optional[torch.Tensor] = None,
        cross_past_key_values: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ):
        """Embed ``input_ids``, run every layer, and apply the final layer norm.

        Returns:
            ``(hidden_states, cross_attentions)`` where ``cross_attentions`` is a tuple holding
            the second output of each layer, in layer order.
        """
        flat_ids = input_ids.view(-1, input_ids.size()[-1])
        token_embeds = self.embed_tokens(flat_ids)

        # Token + positional embedding: each batch row picks its position vector using its own
        # cache position. NOTE(review): cat(...).unsqueeze(1) assumes one token per row (decode step).
        position_table = self.embed_positions.weight
        per_row = [
            position_table[cache_position[row]] + token_embeds[row]
            for row in range(token_embeds.shape[0])
        ]
        hidden_states = torch.cat(per_row, dim=0).unsqueeze(1)

        # Add two singleton axes to the mask (assumes a 2-D mask; TODO confirm with callers).
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]

        collected_cross_attn = []
        for layer, self_kv, cross_kv in zip(self.layers, self_past_key_values, cross_past_key_values):
            hidden_states, layer_cross_attn = layer(
                hidden_states,
                attention_mask=attention_mask,
                self_past_key_value=self_kv,
                cross_past_key_value=cross_kv,
                cache_position=cache_position,
                block_tables=block_tables,
            )
            collected_cross_attn.append(layer_cross_attn)

        return self.layer_norm(hidden_states), tuple(collected_cross_attn)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class WhisperDecoderLayer(nn.Module):
    """One Whisper decoder layer built from replacement self/cross attention modules.

    Each sub-block normalizes its input first, then adds the sub-block output back to the
    residual.

    Args:
        decoder_layer: Original decoder layer; its layer norms, ``fc1``/``fc2`` and
            ``activation_fn`` are reused.
        self_attn: Replacement self-attention module.
        cross_attn: Replacement cross-attention module; must return a 2-tuple.
    """

    def __init__(self, decoder_layer, self_attn, cross_attn):
        super().__init__()
        self._original_mod = decoder_layer
        self.self_attn = self_attn
        self.encoder_attn = cross_attn
        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
        self.final_layer_norm = decoder_layer.final_layer_norm
        self.activation_fn = decoder_layer.activation_fn
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply self-attention, cross-attention and the feed-forward block.

        Returns:
            ``(hidden_states, cross_attn_weights)`` — the layer output and the second value
            returned by the cross-attention module. (The original annotation ``-> torch.Tensor``
            did not match this tuple return.)
        """
        # Self Attention Block
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
        )
        hidden_states = residual + hidden_states

        # Cross-Attention Block (keys/values come from the precomputed cross cache)
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states, cross_attn_weights = self.encoder_attn(
            hidden_states=hidden_states,
            past_key_value=cross_past_key_value,
        )
        hidden_states = residual + hidden_states

        # Fully Connected Block
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, cross_attn_weights
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class WhisperAttention(nn.Module):
    """Base wrapper that reuses an existing attention module's projections and head geometry.

    Args:
        attn: Original attention module exposing ``q_proj``, ``k_proj``, ``v_proj``,
            ``out_proj``, ``num_heads`` and ``embed_dim``.
    """

    def __init__(self, attn):
        super().__init__()
        self._original_mod = attn
        # Share the original projection modules (no weight copies); assigned in q, k, v, out order.
        self.q_proj, self.k_proj, self.v_proj, self.out_proj = (
            attn.q_proj,
            attn.k_proj,
            attn.v_proj,
            attn.out_proj,
        )
        heads, width = attn.num_heads, attn.embed_dim
        self.num_heads = heads
        self.embed_dim = width
        self.head_dim = width // heads
        self.scaling = pow(self.head_dim, -0.5)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split the last dim into heads: (bsz, seq_len, embed) -> (bsz, num_heads, seq_len, head_dim)."""
        split = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return split.transpose(1, 2)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
class WhisperSelfAttention(WhisperAttention):
    """Decoder self-attention that dispatches to RBLN paged-attention decode custom ops."""

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Reshape (bsz, seq_len, embed) into the 5-D (bsz, num_heads, 1, seq_len, head_dim) layout."""
        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute self-attention for the current step using the paged KV cache.

        Args:
            hidden_states: Input of shape (bsz, tgt_len, embed_dim) (unpacked via ``.size()``).
            past_key_value: ``(key_cache, value_cache)``; indexed unconditionally, so it must be
                provided despite the ``None`` default.
            attention_mask: Optional mask; when given, ``paged_attn_decode`` is used, otherwise
                ``paged_causal_attn_decode``.
            cache_position: Current decode position(s); expanded to ``(bsz, 1)`` for the op.
            block_tables: Block table passed through to the custom op.

        Returns:
            The attention output after ``out_proj``, reshaped to (bsz, tgt_len, embed_dim).
        """
        bsz, tgt_len, _ = hidden_states.size()
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
        # Query is pre-scaled here, which is why the op below receives scale=1.0.
        query_states = query_states * self.scaling

        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        # Block size is read from the second-to-last dim of the key cache.
        block_size = past_key_value[0].shape[-2]

        # NOTE(review): argument semantics are defined by the RBLN custom ops; the caches are
        # viewed into the same 5-D (bsz, num_heads, 1, -1, head_dim) layout as q/k/v.
        args = {
            "q": query_states,
            "k": key_states,
            "v": value_states,
            "kcache": past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
            "vcache": past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
            "seq": cache_position.expand(bsz, 1),
            "scale": torch.tensor(1.0, dtype=torch.float32),
            "block_table": block_tables,
            "block_size": block_size,
        }

        if attention_mask is not None:
            # Add one more singleton axis so the mask rank matches the 5-D q/k/v layout
            # (presumably; depends on the mask shape produced by the caller).
            args["mask"] = attention_mask.unsqueeze(2)
            attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(**args)
        else:
            args["mask"] = None
            attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(**args)

        # Merge heads back: (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class WhisperCrossAttention(WhisperAttention):
    """Decoder cross-attention over precomputed encoder key/value states."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend decoder queries over cached encoder key/value states.

        Args:
            hidden_states: Decoder input, unpacked as ``(batch_size, query_len, _)``.
            past_key_value: ``(key_states, value_states)``, used directly as keys and values (no
                k/v projection is applied here). Presumably shaped
                ``(batch, num_heads, encoder_len, head_dim)``; confirm against the cache producer.
                Required, despite the ``None`` default kept for signature compatibility.

        Returns:
            ``(attn_output, attn_weights)``. The first is the ``out_proj`` projection of the
            ``(batch_size, query_len, embed_dim)`` attention output; the second holds the
            softmax-normalized attention weights.

        Raises:
            ValueError: If ``past_key_value`` is ``None``.
        """
        if past_key_value is None:
            raise ValueError("WhisperCrossAttention requires `past_key_value` (encoder key/value states).")

        batch_size, query_len, _ = hidden_states.size()
        query_states = self._shape(self.q_proj(hidden_states), query_len, batch_size)
        query_states = query_states * self.scaling

        key_states = past_key_value[0]
        value_states = past_key_value[1]

        # No attention mask is applied over the encoder positions.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
|
|
16
|
+
from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Public names re-exported by this subpackage: configuration classes first, then model classes.
__all__ = [
    "RBLNXLMRobertaModelConfig",
    "RBLNXLMRobertaForSequenceClassificationConfig",
    "RBLNXLMRobertaModel",
    "RBLNXLMRobertaForSequenceClassification",
]
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ...configuration_generic import (
|
|
16
|
+
RBLNModelForSequenceClassificationConfig,
|
|
17
|
+
RBLNTransformerEncoderForFeatureExtractionConfig,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """Configuration for the RBLN XLM-RoBERTa base model.

    Every option is inherited unchanged from ``RBLNTransformerEncoderForFeatureExtractionConfig``;
    this subclass defines no parameters of its own.
    """
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
    """Configuration for the RBLN XLM-RoBERTa sequence-classification model.

    Every option is inherited unchanged from ``RBLNModelForSequenceClassificationConfig``;
    this subclass defines no parameters of its own.
    """
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, SequenceClassifierOutput
|
|
19
|
+
|
|
20
|
+
from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
    """XLM-RoBERTa base model optimized for RBLN NPU."""

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple]:
        """Forward pass for the RBLN-optimized XLM-RoBERTa base model.

        ``token_type_ids`` is handed to the parent ``forward`` only when it is not ``None``.
        Every other keyword argument is passed through untouched.

        Args:
            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional):
                Indices of input sequence tokens in the vocabulary.
            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional):
                Mask that keeps attention off padding token positions.
            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional):
                Segment indices distinguishing different portions of the input.

        Returns:
            Whatever the parent ``forward`` returns: a tuple of tensors when ``return_dict=False``
            is passed, otherwise a ``BaseModelOutputWithPoolingAndCrossAttentions``.
        """
        # kwargs can never already hold "token_type_ids" (it binds to the named parameter),
        # so merging it in only when present matches a setdefault on kwargs.
        extra = kwargs if token_type_ids is None else {**kwargs, "token_type_ids": token_type_ids}
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **extra)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
    """XLM-RoBERTa model for sequence classification tasks optimized for RBLN NPU."""

    # Presumably the input names of the compiled RBLN graph; token_type_ids is not among them.
    # NOTE(review): forward() still passes token_type_ids on when a caller supplies it;
    # confirm the parent forward tolerates that for this model.
    rbln_model_input_names = ["input_ids", "attention_mask"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, tuple]:
        """Forward pass for the RBLN-optimized XLM-RoBERTa sequence-classification model.

        ``token_type_ids`` is handed to the parent ``forward`` only when it is not ``None``.
        Every other keyword argument is passed through untouched.

        Args:
            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
                Indices of input sequence tokens in the vocabulary.
            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
                Mask that keeps attention off padding token positions.
            token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
                Segment indices distinguishing first and second portions of the input.

        Returns:
            Whatever the parent ``forward`` returns: a tuple of tensors when ``return_dict=False``
            is passed, otherwise a ``SequenceClassifierOutput``.
        """
        # kwargs can never already hold "token_type_ids" (it binds to the named parameter),
        # so merging it in only when present matches a setdefault on kwargs.
        extra = kwargs if token_type_ids is None else {**kwargs, "token_type_ids": token_type_ids}
        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **extra)
|
|
File without changes
|