optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Optional, Tuple
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from transformers.modeling_outputs import ModelOutput
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class RBLNDecoderOnlyOutput(ModelOutput):
    """
    Output container exposing `logits`, `generate_idx` and `padded_cache_lengths`.

    Every field defaults to `None`, so instances can be built with any subset of the fields.
    """

    # Annotated as a float tensor.
    logits: torch.FloatTensor = None
    # Annotated as a tensor; presumably the next generation position(s) -- TODO confirm against the producer.
    generate_idx: torch.Tensor = None
    # Annotated as an int; meaning is defined by the runtime that fills it in (not visible here).
    padded_cache_lengths: int = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
    """
    Extension of `RBLNDecoderOnlyOutput` that adds an optional `attention_mask` field.

    The inherited fields are unchanged; the extra field defaults to `None`.
    """

    # Optional tensor carried alongside the inherited decoder-only outputs.
    attention_mask: Optional[torch.Tensor] = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
    """
    Output container exposing `last_hidden_states` and `params`.

    Both fields default to `None`.
    """

    # Annotated as a float tensor.
    last_hidden_states: torch.FloatTensor = None
    # Annotated as a tuple of float tensors; presumably distribution parameters -- TODO confirm with the caller.
    params: Tuple[torch.FloatTensor] = None
|
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
16
|
+
|
|
17
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
18
|
+
# you may not use this file except in compliance with the License.
|
|
19
|
+
# You may obtain a copy of the License at:
|
|
20
|
+
|
|
21
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
22
|
+
|
|
23
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
24
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
25
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
26
|
+
# See the License for the specific language governing permissions and
|
|
27
|
+
# limitations under the License.
|
|
28
|
+
|
|
29
|
+
import math
|
|
30
|
+
from typing import Optional, Tuple
|
|
31
|
+
|
|
32
|
+
import torch
|
|
33
|
+
from transformers import PretrainedConfig
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    seq_len: Optional[int] = None,
) -> Tuple["torch.Tensor", float]:
    """
    Builds the inverse-frequency vector of the original (unscaled) RoPE formulation.
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. `rope_theta` is read as the base; `partial_rotary_factor` (default 1.0) and
            `head_dim` are read when present, with `hidden_size // num_attention_heads` used when `head_dim` is
            missing or `None`.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (always 1.0 for this type of RoPE).
    """
    theta = config.rope_theta
    rotary_fraction = getattr(config, "partial_rotary_factor", 1.0)

    # Prefer an explicit per-head dimension; derive it from the hidden size otherwise.
    per_head_dim = getattr(config, "head_dim", None)
    if per_head_dim is None:
        per_head_dim = config.hidden_size // config.num_attention_heads

    rotary_dim = int(per_head_dim * rotary_fraction)

    # One exponent per rotated pair of channels: 0/dim, 2/dim, 4/dim, ...
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim
    inverse_frequencies = 1.0 / (theta**exponents)

    # The attention factor is not used by this RoPE variant, hence the neutral 1.0.
    return inverse_frequencies, 1.0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _compute_linear_scaling_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    seq_len: Optional[int] = None,
) -> Tuple["torch.Tensor", float]:
    """
    Builds the inverse frequencies for linearly scaled RoPE. Credits to the Reddit user /u/kaiokendev
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. `rope_scaling["factor"]` is read as the scaling divisor.
        seq_len (`int`, *optional*):
            The current sequence length. Only forwarded to the default RoPE computation.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (passed through from the default RoPE).
    """
    scale = config.rope_scaling["factor"]

    # Start from the unscaled frequencies.
    scaled_inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)

    # NOTE: the original technique scales the position_ids. Since `embs = inv_freq @ position_ids`, dividing the
    # inverse frequencies by the same factor is equivalent. The division is done in place on the fresh tensor.
    scaled_inv_freq /= scale

    return scaled_inv_freq, attention_factor
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _compute_dynamic_ntk_parameters(
    config: Optional["PretrainedConfig"] = None,
    seq_len: Optional[int] = None,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    Unlike a single-vector RoPE initialisation, this returns one row of inverse frequencies per position
    `1..seq_len`, so the result has shape `(seq_len, dim // 2)`. Rows for positions that do not exceed
    `config.max_position_embeddings` use the unscaled base.
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. Reads `rope_theta`, `max_position_embeddings`, `rope_scaling["factor"]`,
            optionally `partial_rotary_factor` and `head_dim` (falling back to
            `hidden_size // num_attention_heads` when `head_dim` is missing or `None`).
        seq_len (`int`, *optional*):
            The number of positions to build frequencies for. Defaults to `config.max_position_embeddings`
            when `None`.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """

    base = config.rope_theta
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)

    # Resolve head_dim the same way as `_compute_default_rope_parameters`: an attribute that exists but is
    # `None` must fall back to the derived value, and the fallback is only evaluated when it is needed.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    dim = int(head_dim * partial_rotary_factor)

    max_position_embeddings = config.max_position_embeddings
    factor = config.rope_scaling["factor"]

    # `seq_len` is optional in the signature; without this fallback the arithmetic below raises TypeError.
    if seq_len is None:
        seq_len = max_position_embeddings

    attention_factor = 1.0  # Unused in this type of RoPE

    # Loop-invariant pieces of the formula.
    exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    ntk_power = dim / (dim - 2)

    # Process with chunk_size to reduce precision error
    chunk_size = 4096

    inv_freq_list = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)

        # 1-based sequence lengths as a column vector, floored at max_position_embeddings so that short
        # sequences keep the original base.
        seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
        seq_lens = seq_lens.clamp(min=max_position_embeddings)

        # Compute the inverse frequencies for each chunk
        scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** ntk_power
        inv_freq_list.append(1.0 / (scaled_base**exponents))

    final_inv_freq = torch.cat(inv_freq_list, dim=0)

    return final_inv_freq, attention_factor
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _compute_yarn_parameters(config: "PretrainedConfig", seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with YaRN (NTK-by-parts) scaling. Please refer to the
    [original paper](https://arxiv.org/abs/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. Reads `rope_theta`, `max_position_embeddings`, `rope_scaling["factor"]`,
            the optional `rope_scaling` keys `attention_factor`, `beta_fast` and `beta_slow`, and optionally
            `partial_rotary_factor` and `head_dim` (falling back to `hidden_size // num_attention_heads` when
            `head_dim` is missing or `None`).
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """

    base = config.rope_theta
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)

    # Resolve head_dim the same way as `_compute_default_rope_parameters`: an attribute that exists but is
    # `None` must fall back to the derived value, and the fallback is only evaluated when it is needed.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    dim = int(head_dim * partial_rotary_factor)

    max_position_embeddings = config.max_position_embeddings
    factor = config.rope_scaling["factor"]

    # Sets the attention factor as suggested in the paper
    attention_factor = config.rope_scaling.get("attention_factor")
    if attention_factor is None:
        attention_factor = 0.1 * math.log(factor) + 1.0

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
        """Find dimension range bounds based on rotations"""
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(ramp_start, ramp_end, dim):
        """Ramp from 0 to 1 over `[ramp_start, ramp_end]`, clamped outside that interval"""
        if ramp_start == ramp_end:
            ramp_end += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - ramp_start) / (ramp_end - ramp_start)
        return torch.clamp(linear_func, 0, 1)

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )

    return inv_freq, attention_factor
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _compute_longrope_parameters(
    config: PretrainedConfig, seq_len: Optional[int] = None
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
    [original implementation](https://github.com/microsoft/LongRoPE)

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. This function reads `rope_theta`, `max_position_embeddings` and
            `rope_scaling` (keys `"long_factor"` and `"short_factor"`, plus the optional `"factor"` and
            `"attention_factor"`). `partial_rotary_factor`, `head_dim` and `original_max_position_embeddings`
            are optional; when `head_dim` is missing or `None`, `hidden_size // num_attention_heads` is used.
            When `original_max_position_embeddings` is missing or `None`, `rope_scaling["factor"]` must be set.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    base = config.rope_theta
    partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
    # A config may declare `head_dim` but leave it as `None`; in that case (or when the attribute is absent)
    # derive it from the hidden size. The fallback is only evaluated when it is actually needed.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    dim = int(head_dim * partial_rotary_factor)

    rope_scaling = config.rope_scaling
    long_factor = rope_scaling["long_factor"]
    short_factor = rope_scaling["short_factor"]
    factor = rope_scaling.get("factor")
    attention_factor = rope_scaling.get("attention_factor")

    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    # An attribute that exists but is unset (`None`) is treated the same as a missing attribute.
    original_max_position_embeddings = getattr(config, "original_max_position_embeddings", None)
    if original_max_position_embeddings:
        max_position_embeddings = original_max_position_embeddings
        expanded_max_position_embeddings = config.max_position_embeddings
        factor = expanded_max_position_embeddings / max_position_embeddings
    else:
        max_position_embeddings = config.max_position_embeddings
        expanded_max_position_embeddings = max_position_embeddings * factor

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if factor <= 1.0:
            attention_factor = 1.0
        else:
            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))

    # Compute the inverse frequencies -- scaled based on the target sequence length
    if expanded_max_position_embeddings > max_position_embeddings:
        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
    else:
        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

    return inv_freq, attention_factor
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _compute_llama3_parameters(
    config: PretrainedConfig, seq_len: Optional[int] = None
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies for llama 3.1.

    Starting from the default RoPE inverse frequencies, each frequency is treated according to its wavelength:
    wavelengths shorter than the high-frequency threshold are kept, wavelengths longer than the low-frequency
    threshold are divided by `factor`, and the band in between is blended between the two.

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration. `config.rope_scaling` must provide `"factor"`, `"low_freq_factor"`,
            `"high_freq_factor"` and `"original_max_position_embeddings"`.
        seq_len (`int`, *optional*):
            The current sequence length. Only forwarded to the default RoPE computation.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # Start from the default RoPE parameters
    base_inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)

    scaling = config.rope_scaling
    factor = scaling["factor"]  # `8` in the original implementation
    low_freq_factor = scaling["low_freq_factor"]  # `1` in the original implementation
    high_freq_factor = scaling["high_freq_factor"]  # `4` in the original implementation
    old_context_len = scaling["original_max_position_embeddings"]  # `8192` in the original implementation

    # Wavelength thresholds that separate the three regimes
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / base_inv_freq

    # Regime 1 (wavelen < high_freq_wavelen): left untouched.
    # Regime 2 (wavelen > low_freq_wavelen): divided by `factor`.
    scaled_inv_freq = torch.where(wavelen > low_freq_wavelen, base_inv_freq / factor, base_inv_freq)

    # Regime 3 (everything else): blend the scaled and unscaled value with a smooth weight
    blend = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    blended_inv_freq = (1 - blend) * scaled_inv_freq / factor + blend * scaled_inv_freq
    in_medium_band = ~((wavelen < high_freq_wavelen) | (wavelen > low_freq_wavelen))

    return torch.where(in_medium_band, blended_inv_freq, scaled_inv_freq), attention_factor
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
# Registry keyed by the "rope_type" string field of the rope config. Each value is the function that computes the
# RoPE parameters from the model config. To enable a custom RoPE parameterization, add a new
# {'rope_type': callable} pair to this dictionary; the callable must have the same signature as the entries below.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "linear": _compute_linear_scaling_rope_parameters,
    "dynamic": _compute_dynamic_ntk_parameters,
    "yarn": _compute_yarn_parameters,
    "longrope": _compute_longrope_parameters,
    "llama3": _compute_llama3_parameters,
}
|