optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,144 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Tuple
+
+import torch
+import torch.nn as nn
+
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
+    rotate_half,
+)
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel
+
+
+def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
+    """Applies rotary position embedding to the specified dimension of the tensor."""
+    tensor_, tensor_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
+    tensor_embed = (tensor_ * cos) + (rotate_half(tensor_) * sin)
+    return torch.cat((tensor_embed, tensor_pass), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    rot_dim = cos.shape[-1]
+    q_embed = apply_rotary_to_tensor(q, cos, sin, rot_dim)
+    k_embed = apply_rotary_to_tensor(k, cos, sin, rot_dim)
+    return q_embed, k_embed
+
+
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
+        self.config.rope_theta = 10000
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def get_rbln_attn_class(self):
+        return MidmAttention
+
+    def get_rbln_layer_class(self):
+        return MidmLayer
+
+    def get_rbln_model_class(self):
+        return MidmModel
+
+    def get_model_layer(self, causal_lm: "MidmLMHeadModel"):
+        return causal_lm.transformer
+
+    def get_decoder_layers(self, causal_lm: "MidmLMHeadModel"):
+        return causal_lm.transformer.h
+
+
+class MidmModel(DecoderOnlyModel):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )
+
+        return layernorm1p
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
+        else:
+            return self._original_mod.ln_f
+
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
+
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe
+
+
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )
+
+        return layernorm1p
+
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
+        else:
+            return self._original_mod.ln_1
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
+        else:
+            return self._original_mod.ln_2
+
+
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
+
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for MidmAttention")
+
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
+
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
+
+        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
+
+        return scale
+
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
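The `layernorm1p` closures above implement Mi:dm's LayerNorm1P variant: the learned scale is shifted by +1 before the usual affine step, so weights near zero behave like an identity scale. A minimal standalone check of that equivalence in plain PyTorch (shapes and values are illustrative, not taken from the model):

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(2, 8, 16)            # (batch, seq, hidden); illustrative shape
ln = torch.nn.LayerNorm(16)
ln.weight.data.uniform_(-0.1, 0.1)        # near-zero scale, as LayerNorm1P expects

# What get_layernorm1p() does: pass (weight + 1) to the standard layer_norm kernel.
out_1p = F.layer_norm(hidden, ln.normalized_shape, ln.weight + 1, ln.bias, ln.eps)

# Equivalent: normalize without affine, then scale by (weight + 1) and add the bias.
normalized = F.layer_norm(hidden, ln.normalized_shape, None, None, ln.eps)
out_manual = normalized * (ln.weight + 1) + ln.bias

assert torch.allclose(out_1p, out_manual, atol=1e-6)
```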
@@ -0,0 +1,144 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+from transformers import AutoModelForCausalLM
+from transformers.generation.utils import GenerationMixin
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils import logging
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The MIDM Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based MidmForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers MidmForCausalLM model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNMidmLMHeadModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNMidmLMHeadModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNMidmLMHeadModelConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNMidmLMHeadModel
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNMidmLMHeadModel.from_pretrained(
+            "KT-AI/midm-bitext-S-7B-inst-v1",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=4,
+        )
+
+
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "max_seq_len": 4096,
+            "tensor_parallel_size": 4,
+        }
+        model = RBLNMidmLMHeadModel.from_pretrained(
+            "KT-AI/midm-bitext-S-7B-inst-v1",
+            export=True,
+            rbln_config=rbln_config
+        )
+
+
+        # Using a RBLNMidmLMHeadModelConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNMidmLMHeadModelConfig
+
+        config = RBLNMidmLMHeadModelConfig(
+            batch_size=1,
+            max_seq_len=4096,
+            tensor_parallel_size=4
+        )
+        model = RBLNMidmLMHeadModel.from_pretrained(
+            "KT-AI/midm-bitext-S-7B-inst-v1",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    _decoder_wrapper_cls = MidmLMHeadModelWrapper
+    _hf_class = AutoModelForCausalLM
+    _supports_cache_class = True
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        *,
+        export: Optional[bool] = None,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        trust_remote_code: Optional[bool] = None,
+        **kwargs: Any,
+    ):
+        """
+        The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
+        User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
+
+        Args:
+            model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
+                It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
+                If None, it will be determined based on the existence of the compiled model files in the model_id.
+            rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
+                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNMidmLMHeadModelConfig` for Mi:dm models).
+                For detailed configuration options, see the specific model's configuration class documentation.
+            trust_remote_code (bool): Whether or not to trust the remote code when loading a model from the Hub.
+            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+        Returns:
+            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
+        """
+
+        if trust_remote_code is not None:
+            kwargs["trust_remote_code"] = trust_remote_code
+        elif "trust_remote_code" not in kwargs:
+            kwargs["trust_remote_code"] = True
+
+        return super().from_pretrained(
+            model_id=model_id,
+            export=export,
+            rbln_config=rbln_config,
+            **kwargs,
+        )
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(GenerationMixin, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
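Once compiled, the wrapper is used like a regular Hugging Face causal LM: `__getattr__` above forwards generation helpers such as `generate()` to `GenerationMixin`. A hedged end-to-end sketch (the prompt and generation arguments are illustrative, not taken from the package):

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNMidmLMHeadModel

model_id = "KT-AI/midm-bitext-S-7B-inst-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Compile for RBLN NPUs once; export=True triggers graph conversion and compilation.
model = RBLNMidmLMHeadModel.from_pretrained(
    model_id,
    export=True,
    rbln_batch_size=1,
    rbln_tensor_parallel_size=4,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")  # illustrative prompt
output_ids = model.generate(**inputs, max_new_tokens=32)      # generate() resolved via GenerationMixin
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```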
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_mistral import RBLNMistralForCausalLMConfig, RBLNMistralModelConfig
+from .modeling_mistral import RBLNMistralForCausalLM, RBLNMistralModel
@@ -0,0 +1,50 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Mistral models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNMistralForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=4096,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNMistralForCausalLM.from_pretrained(
+        "mistralai/Mistral-7B-v0.1",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNMistralModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Mistral models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class MistralWrapper(DecoderOnlyWrapper):
+    pass
@@ -0,0 +1,115 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import PretrainedConfig
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from .mistral_architecture import MistralWrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Mistral Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based MistralForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers MistralForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNMistralForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNMistralForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNMistralForCausalLMConfig`] class for all available configuration options.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNMistralForCausalLM
+
+        # Simple usage using rbln_* arguments
+        # `max_seq_len` is automatically inferred from the model config
+        model = RBLNMistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1",
+            export=True,
+            rbln_batch_size=1,
+            rbln_tensor_parallel_size=4,
+        )
+
+        # Using a config dictionary
+        rbln_config = {
+            "batch_size": 1,
+            "max_seq_len": 4096,
+            "tensor_parallel_size": 4,
+        }
+        model = RBLNMistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1",
+            export=True,
+            rbln_config=rbln_config
+        )
+
+        # Using a RBLNMistralForCausalLMConfig instance (recommended for type checking)
+        from optimum.rbln import RBLNMistralForCausalLMConfig
+
+        config = RBLNMistralForCausalLMConfig(
+            batch_size=1,
+            max_seq_len=4096,
+            tensor_parallel_size=4
+        )
+        model = RBLNMistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1",
+            export=True,
+            rbln_config=config
+        )
+        ```
+    """
+
+    _decoder_wrapper_cls = MistralWrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+
+        return rbln_config
+
+
+class RBLNMistralModel(RBLNDecoderOnlyModel):
+    """
+    The Mistral Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    """
+
+    _decoder_wrapper_cls = MistralWrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+
+        return rbln_config
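`_update_sliding_window_config` copies the sliding-window attention settings from the Hugging Face Mistral config into the RBLN config and marks every decoder layer as a sliding-window layer. A minimal sketch of that mapping with stand-in objects (the real code receives a `transformers.PretrainedConfig` and an `RBLNMistralForCausalLMConfig`; the values below are illustrative):

```python
from types import SimpleNamespace

# Stand-ins for the HF model config and the RBLN config (illustrative values).
model_config = SimpleNamespace(sliding_window=4096, num_hidden_layers=32)
rbln_config = SimpleNamespace(cache_impl=None, sliding_window=None, sliding_window_layers=None)

# Same assignments as _update_sliding_window_config() above.
rbln_config.cache_impl = "sliding_window"
rbln_config.sliding_window = model_config.sliding_window
rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))

# Every one of the 32 decoder layers now uses the sliding-window KV cache.
assert rbln_config.sliding_window_layers == list(range(32))
```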
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_opt import RBLNOPTForCausalLMConfig, RBLNOPTModelConfig
+from .modeling_opt import RBLNOPTForCausalLM, RBLNOPTModel
@@ -0,0 +1,29 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for OPT causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
+
+
+class RBLNOPTModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for OPT model.
+    Inherits from RBLNDecoderOnlyModelConfig with no additional parameters.
+    """
@@ -0,0 +1,102 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from .opt_architecture import OPTWrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class MLP(nn.Module):
+    def __init__(self, fc1, fc2, activation_fn):
+        super(MLP, self).__init__()
+        self.fc1 = fc1
+        self.fc2 = fc2
+        self.activation_fn = activation_fn
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = self.fc2(x)
+        return x
+
+
+class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The OPT Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based OPTForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers OPTForCausalLM model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNOPTForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNOPTForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNOPTForCausalLMConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = OPTWrapper
+    _use_rotary_emb = False
+
+    def modify_opt_decoder_layer(layer):
+        mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+        layer.mlp = mlp
+        del layer.fc1
+        del layer.fc2
+        del layer.activation_fn
+
+        return layer
+
+    @classmethod
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+        for i in range(len(model.model.decoder.layers)):
+            model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
+
+        return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
+
+
+class RBLNOPTModel(RBLNDecoderOnlyModel):
+    """
+    The OPT Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    """
+
+    _decoder_wrapper_cls = OPTWrapper
+    _use_rotary_emb = False
+
+    def modify_opt_decoder_layer(layer):
+        mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+        layer.mlp = mlp
+        del layer.fc1
+        del layer.fc2
+        del layer.activation_fn
+
+        return layer
+
+    @classmethod
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+        for i in range(len(model.decoder.layers)):
+            model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
+
+        return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
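`modify_opt_decoder_layer` regroups each OPT decoder layer's `fc1`, `fc2`, and activation into a single `mlp` submodule so the shared decoder-only wrapper can address the feed-forward block uniformly. A toy sketch of the same restructuring, not the package's own classes (layer sizes are illustrative):

```python
import torch
import torch.nn as nn

class ToyOPTLayer(nn.Module):
    """Minimal stand-in for the feed-forward parts of transformers' OPTDecoderLayer."""
    def __init__(self, hidden=8, ffn=32):
        super().__init__()
        self.fc1 = nn.Linear(hidden, ffn)
        self.fc2 = nn.Linear(ffn, hidden)
        self.activation_fn = nn.ReLU()

class MLP(nn.Module):
    """Groups fc1 -> activation -> fc2 under one submodule, as in modeling_opt.py above."""
    def __init__(self, fc1, fc2, activation_fn):
        super().__init__()
        self.fc1, self.fc2, self.activation_fn = fc1, fc2, activation_fn
    def forward(self, x):
        return self.fc2(self.activation_fn(self.fc1(x)))

layer = ToyOPTLayer()
x = torch.randn(2, 8)
before = layer.fc2(layer.activation_fn(layer.fc1(x)))

# Same regrouping as modify_opt_decoder_layer(): move fc1/fc2/activation under layer.mlp.
layer.mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
del layer.fc1, layer.fc2, layer.activation_fn

after = layer.mlp(x)
assert torch.allclose(before, after)  # behavior is unchanged; only the module layout differs
```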