optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import importlib
|
|
15
|
+
import inspect
|
|
16
|
+
import warnings
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any, Dict, Optional, Type, Union
|
|
19
|
+
|
|
20
|
+
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
|
|
21
|
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
|
22
|
+
from transformers.models.auto.auto_factory import _get_model_class
|
|
23
|
+
|
|
24
|
+
from optimum.rbln.configuration_utils import RBLNAutoConfig, RBLNModelConfig
|
|
25
|
+
from optimum.rbln.modeling_base import RBLNBaseModel
|
|
26
|
+
from optimum.rbln.utils.model_utils import (
|
|
27
|
+
MODEL_MAPPING,
|
|
28
|
+
convert_hf_to_rbln_model_name,
|
|
29
|
+
convert_rbln_to_hf_model_name,
|
|
30
|
+
get_rbln_model_cls,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class _BaseAutoModelClass:
    """Base class for RBLN auto-model factories.

    Subclasses (e.g. ``RBLNAutoModelForCausalLM``) set ``_model_mapping`` /
    ``_model_mapping_names`` and inherit the dispatch logic that resolves the
    concrete ``RBLN*`` class from a model id or an in-memory HF model.
    """

    # Mapping from HF config classes to HF model classes; populated by subclasses.
    _model_mapping = None

    def __init__(self, *args, **kwargs):
        # Auto classes are pure factories; direct instantiation is always an error.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`"
        )

    @classmethod
    def get_rbln_cls(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *args: Any,
        export: Optional[bool] = None,
        **kwargs: Any,
    ):
        """
        Determine the appropriate RBLN model class based on the given model ID and configuration.

        Args:
            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
            export (bool): Whether to infer the class based on HuggingFace (HF) architecture.
                When `None`, it is inferred by checking for compiled artifacts at the model ID.
            kwargs: Additional arguments for configuration and loading.

        Returns:
            RBLNBaseModel: The corresponding RBLN model class.

        Raises:
            ValueError: If the resolved architecture is not supported by this auto class.
            AttributeError: If the resolved class name is not exported by `optimum.rbln`.
        """
        if isinstance(pretrained_model_name_or_path, Path):
            pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()

        # If the caller did not say, decide by probing for compiled RBLN artifacts.
        if export is None:
            export = not RBLNBaseModel._is_compiled(
                model_id=pretrained_model_name_or_path,
                token=kwargs.get("token"),
                revision=kwargs.get("revision"),
                force_download=kwargs.get("force_download", False),
                cache_dir=kwargs.get("cache_dir"),
                subfolder=kwargs.get("subfolder", ""),
                local_files_only=kwargs.get("local_files_only", False),
            )

        if export:
            # Compiling from an HF checkpoint: derive the RBLN class from the HF architecture.
            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
        else:
            # Loading a compiled artifact: the class name is recorded in rbln_config.json.
            rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)

        if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
            raise ValueError(
                f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
                "Please use the `from_pretrained()` method of the appropriate class to load this model, "
                f"or directly use '{rbln_class_name}.from_pretrained()`."
            )

        try:
            rbln_cls = get_rbln_model_cls(rbln_class_name)
        except AttributeError as e:
            raise AttributeError(
                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
            ) from e

        return rbln_cls

    @classmethod
    def infer_hf_model_class(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *args: Any,
        **kwargs: Any,
    ):
        """
        Infer the HuggingFace model class based on the configuration or model name.

        Args:
            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
            kwargs: Additional arguments for configuration and loading.

        Returns:
            PretrainedModel: The inferred HuggingFace model class.

        Raises:
            ValueError: If the configuration class is not recognized by this auto class.
        """
        # Try to load configuration if provided or retrieve it from the model ID
        config = kwargs.pop("config", None)
        # NOTE(review): trust_remote_code is forced on here so remote architectures
        # resolve via `auto_map`; this bypasses the usual user opt-in.
        kwargs.update({"trust_remote_code": True})
        kwargs["_from_auto"] = True

        # Load configuration if not already provided
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                **kwargs,
            )

        # Get hf_model_class from Config: prefer a remote-code mapping, then the
        # static config-class -> model-class mapping of this auto class.
        has_remote_code = (
            hasattr(config, "auto_map") and convert_rbln_to_hf_model_name(cls.__name__) in config.auto_map
        )
        if has_remote_code:
            class_ref = config.auto_map[convert_rbln_to_hf_model_name(cls.__name__)]
            model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
        else:
            raise ValueError(
                f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
                f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
            )

        # `architectures` may be absent/None on some configs; only warn about a
        # mismatch when the config actually declares an architecture.
        architectures = getattr(config, "architectures", None)
        if architectures and model_class.__name__ != architectures[0]:
            warnings.warn(
                f"`{cls.__name__}.from_pretrained()` is invoking `{convert_hf_to_rbln_model_name(model_class.__name__)}.from_pretrained()`, which does not match the "
                f"expected architecture `RBLN{architectures[0]}` from config. This mismatch could cause some operations to not be properly loaded "
                f"from the checkpoint, leading to potential unintended behavior. If this is not intentional, consider calling the "
                f"`from_pretrained()` method directly from the `RBLN{architectures[0]}` class instead.",
                UserWarning,
            )

        return model_class

    @classmethod
    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
        """
        Retrieve the RBLN model class name recorded in a compiled model's config.

        Locates the compiled model directory for the given model ID and reads
        `rbln_model_cls_name` from its `rbln_config.json`.

        Args:
            pretrained_model_name_or_path (str): Identifier of the model.
            kwargs: Additional arguments that match the parameters of `_load_compiled_model_dir`.

        Returns:
            str: Name of the RBLN model class (e.g. `"RBLNLlamaForCausalLM"`).
        """
        # Forward only the kwargs that `_load_compiled_model_dir` actually accepts.
        sig = inspect.signature(RBLNBaseModel._load_compiled_model_dir)
        valid_params = sig.parameters.keys()
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}

        model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
            model_id=pretrained_model_name_or_path, **filtered_kwargs
        )
        rbln_config = RBLNAutoConfig.load(model_path_subfolder)

        return rbln_config.rbln_model_cls_name

    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        export: Optional[bool] = None,
        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
        **kwargs,
    ):
        """
        Load an RBLN-accelerated model from a pretrained checkpoint or a compiled RBLN artifact.

        This convenience method determines the concrete `RBLN*` model class that matches the
        underlying HuggingFace architecture and dispatches to that class's
        `from_pretrained()` implementation. Depending on whether a compiled RBLN folder is
        detected (or if `export=True` is passed), it will either:

        - Compile from a HuggingFace checkpoint to an RBLN model
        - Or load an already-compiled RBLN model directory/repository

        Args:
            model_id:
                HF repo id or local path. For compiled models, this should point to a directory
                (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
            export:
                Force compilation from a HuggingFace checkpoint. When `None`, this is inferred by
                checking whether compiled artifacts exist at `model_id`.
            rbln_config:
                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
                instance of the specific model's config class (e.g., `RBLNLlamaForCausalLMConfig`).
            kwargs: Additional keyword arguments.
                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
                - Remaining arguments are forwarded to the HuggingFace loader (e.g., `revision`,
                  `token`, `trust_remote_code`, `cache_dir`, `subfolder`, `local_files_only`).

        Returns:
            An instantiated RBLN model ready for inference on RBLN NPUs.
        """
        rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
        return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
        config: Optional[PretrainedConfig] = None,
        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
        **kwargs: Any,
    ) -> RBLNBaseModel:
        """
        Convert and compile an in-memory HuggingFace model into an RBLN model.

        This method resolves the appropriate concrete `RBLN*` class from the input model's class
        name (e.g., `LlamaForCausalLM` -> `RBLNLlamaForCausalLM`) and then delegates to that
        class's `from_model()` implementation.

        Args:
            model: A HuggingFace model instance to convert.
            config: The configuration object associated with the model.
            rbln_config:
                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
                instance of the specific model's config class.
            kwargs: Additional keyword arguments.
                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.

        Returns:
            An instantiated RBLN model ready for inference on RBLN NPUs.
        """
        rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
        return rbln_cls.from_model(model, config=config, rbln_config=rbln_config, **kwargs)

    @staticmethod
    def register(rbln_cls: Type[RBLNBaseModel], exist_ok: bool = False):
        """
        Register a new RBLN model class.

        Args:
            rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
            exist_ok (bool): Whether to allow registering an already registered model.

        Raises:
            ValueError: If `rbln_cls` is not an RBLNBaseModel subclass, or if it is
                already registered and `exist_ok` is False.
        """
        if not issubclass(rbln_cls, RBLNBaseModel):
            raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")

        # A class counts as "already registered" if it is either in the dynamic
        # MODEL_MAPPING or natively exported by the optimum.rbln package.
        native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
        if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
            if not exist_ok:
                raise ValueError(f"Model for {rbln_cls.__name__} already registered.")

        MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from transformers.models.auto.modeling_auto import (
|
|
16
|
+
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
|
|
17
|
+
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
|
18
|
+
MODEL_FOR_CAUSAL_LM_MAPPING,
|
|
19
|
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
|
20
|
+
MODEL_FOR_CTC_MAPPING,
|
|
21
|
+
MODEL_FOR_CTC_MAPPING_NAMES,
|
|
22
|
+
MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
|
|
23
|
+
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
|
|
24
|
+
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
|
25
|
+
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
|
26
|
+
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
|
|
27
|
+
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
|
|
28
|
+
MODEL_FOR_MASKED_LM_MAPPING,
|
|
29
|
+
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
|
30
|
+
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
|
31
|
+
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
|
32
|
+
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
|
33
|
+
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
|
34
|
+
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
|
35
|
+
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
|
36
|
+
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
|
37
|
+
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
|
38
|
+
MODEL_FOR_TEXT_ENCODING_MAPPING,
|
|
39
|
+
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
|
|
40
|
+
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
|
41
|
+
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
|
|
42
|
+
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
|
|
43
|
+
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
|
|
44
|
+
MODEL_MAPPING,
|
|
45
|
+
MODEL_MAPPING_NAMES,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
from .auto_factory import _BaseAutoModelClass
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Teach the HF causal-LM auto-mapping about RBLN-supported custom
# architectures so they resolve by model type like any stock model.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
    {
        "exaone": "ExaoneForCausalLM",
        "midm": "MidmLMHeadModel",
    }
)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class RBLNAutoModel(_BaseAutoModelClass):
    """Auto class resolving any supported transformers architecture to its RBLN model."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_MAPPING_NAMES
    _model_mapping = MODEL_MAPPING
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class RBLNAutoModelForCTC(_BaseAutoModelClass):
    """Auto class resolving models with a Connectionist Temporal Classification (CTC) head."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
    _model_mapping = MODEL_FOR_CTC_MAPPING
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
    """Automatically detect Causal Language Models."""

    # Fixes: removed a stray empty string literal left in the class body,
    # and corrected the "Casual" -> "Causal" docstring typo.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
    _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    """Auto class resolving encoder-decoder (sequence-to-sequence) language models."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    """Automatically detect Speech Sequence to Sequence Models."""

    # Fix: docstring previously carried the generic seq2seq description;
    # this class maps speech-to-text encoder-decoder architectures.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
    _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
    """Automatically detect Depth Estimation Models."""

    # Fix: docstring previously said "Speech Sequence to Sequence Language
    # Models" (copy-paste shuffle); this class maps depth-estimation models.
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
    _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
    """Auto class resolving models with a sequence-classification head."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
    """Auto class resolving vision-to-sequence generation models."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
    """Auto class resolving image-and-text to text generation models."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
    """Automatically detect Masked Language Models."""

    # Fix: corrected the "Lanuage" -> "Language" docstring typo.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
    _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
    """Auto class resolving models with an audio-classification head."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
    """Auto class resolving models with an image-classification head."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
    """Auto class resolving models with a question-answering head."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
    """Auto class resolving text-encoding (embedding) models."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    """Auto class resolving zero-shot object-detection models."""

    # Transformers auto-mappings consulted when resolving the concrete class.
    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ....ops import paged_attn_decode, paged_causal_attn_decode
|
|
16
|
+
from .configuration_bart import RBLNBartForConditionalGenerationConfig, RBLNBartModelConfig
|
|
17
|
+
from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
20
|
+
from transformers.utils import logging
|
|
21
|
+
|
|
22
|
+
from ..seq2seq.seq2seq_architecture import (
|
|
23
|
+
Seq2SeqCrossAttention,
|
|
24
|
+
Seq2SeqDecoder,
|
|
25
|
+
Seq2SeqDecoderLayer,
|
|
26
|
+
Seq2SeqDecoderWrapper,
|
|
27
|
+
Seq2SeqEncoderWrapper,
|
|
28
|
+
Seq2SeqForConditionalGeneration,
|
|
29
|
+
Seq2SeqSelfAttention,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
# Module-level logger following the transformers logging convention.
logger = logging.get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BartWrapper:
    """Bundle of the encoder/decoder wrappers used to compile BART for RBLN NPUs."""

    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
        # Encoder is padded/compiled to a fixed maximum sequence length.
        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
        # Decoder attention kernels differ depending on mask usage.
        self.decoder = BartDecoderWrapper(model, use_attention_mask=use_attention_mask)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class BartDecoderWrapper(Seq2SeqDecoderWrapper):
    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        """Rebuild the decoder stack with RBLN-compatible attention modules.

        Each original decoder layer is wrapped with a BART self-attention and
        cross-attention adapter, then the whole decoder is re-assembled into a
        `BartForConditionalGeneration` graph.
        """
        wrapped_layers = [
            BartDecoderLayer(
                layer,
                BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask),
                BartCrossAttention(layer.encoder_attn),
            )
            for layer in model.get_decoder().layers
        ]
        rbln_decoder = BartDecoder(model.get_decoder(), wrapped_layers)
        return BartForConditionalGeneration(model, rbln_decoder)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
    """BART conditional-generation graph; inherits the generic seq2seq behavior unchanged."""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class BartDecoder(Seq2SeqDecoder):
    """RBLN-facing adapter around the HF BART decoder stack."""

    # BART uses learned positional embeddings.
    has_pos_emb = True

    def __post_init__(self):
        self.embed_positions = self._original_mod.embed_positions
        self.layernorm_embedding = self._original_mod.layernorm_embedding
        # `embed_scale` is only present on some configurations; fall back to None.
        self.embed_scale = getattr(self._original_mod, "embed_scale", None)

    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
        """Expand the 1D masks to the 4D layouts expected by the attention ops."""
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)

        return attention_mask, encoder_attention_mask

    def apply_position_embedding(self, inputs_embeds, cache_position):
        """Add learned positional embeddings and apply the embedding layer norm.

        Fix: `self.embed_positions.weight[2:]` was recomputed on every loop
        iteration; it is loop-invariant and is now hoisted out.
        """
        # BART's learned positional table reserves the first 2 rows, hence `[2:]`.
        position_weight = self.embed_positions.weight[2:]
        hidden_all = []
        # Per-sample loop: each batch element may sit at a different decoding position.
        for i in range(inputs_embeds.shape[0]):
            position = position_weight[cache_position[i]]
            hidden_all.append(position + inputs_embeds[i])
        hidden_states = torch.stack(hidden_all, dim=0)

        return self.layernorm_embedding(hidden_states)

    def get_embedding(self):
        """Return the token-embedding callable, scaling embeddings when configured."""
        if self.embed_scale is not None:
            return lambda x: self.embed_tokens(x) * self.embed_scale
        else:
            return self.embed_tokens
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class BartLayerFF(nn.Module):
    """Feed-forward sub-block of a BART decoder layer (residual MLP, post-LN)."""

    def __init__(self, decoder_layer):
        super().__init__()
        # Reuse the original layer's trained sub-modules.
        self.activation_fn = decoder_layer.activation_fn
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2
        self.layer_norm = decoder_layer.final_layer_norm

    def forward(self, hidden_states):
        """Apply fc1 -> activation -> fc2, add the residual, then layer-normalize."""
        shortcut = hidden_states
        projected = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        return self.layer_norm(shortcut + projected)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class BartDecoderLayer(Seq2SeqDecoderLayer):
    """BART decoder-layer adapter. BART is a post-LN architecture, so the
    normalization hooks fire *after* each attention block, never before."""

    def __post_init__(self):
        self.ff_layer = BartLayerFF(self._original_mod)
        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
        self.encoder_attn = self._original_mod.encoder_attn
        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm

    def pre_self_attn_layer_norm(self, hidden_states):
        # Post-LN: nothing to normalize before self-attention.
        return hidden_states

    def post_self_attn_layer_norm(self, hidden_states):
        return self.self_attn_layer_norm(hidden_states)

    def pre_cross_attn_layer_norm(self, hidden_states):
        # Post-LN: nothing to normalize before cross-attention.
        return hidden_states

    def post_cross_attn_layer_norm(self, hidden_states):
        return self.encoder_attn_layer_norm(hidden_states)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class BartSelfAttention(Seq2SeqSelfAttention):
    def __post_init__(self, use_attention_mask: bool = True):
        """Bind the original projection weights and pick the paged-attention kernel."""
        mod = self._original_mod
        self.q_proj = mod.q_proj
        self.k_proj = mod.k_proj
        self.v_proj = mod.v_proj
        self.out_proj = mod.out_proj
        self.num_heads = mod.num_heads
        self.head_dim = mod.embed_dim // mod.num_heads
        self.scaling = self.head_dim**-0.5
        ops = torch.ops.rbln_custom_ops
        # With an explicit mask we use the masked kernel; without one, the causal kernel.
        self.attn_decode = ops.paged_attn_decode if use_attention_mask else ops.paged_causal_attn_decode

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project hidden states to (query, key, value); the query is pre-scaled."""
        return (
            self.q_proj(hidden_states) * self.scaling,
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class BartCrossAttention(Seq2SeqCrossAttention):
    def __post_init__(self):
        """Bind the original cross-attention projections and shape parameters."""
        mod = self._original_mod
        self.q_proj = mod.q_proj
        self.k_proj = mod.k_proj
        self.v_proj = mod.v_proj
        self.out_proj = mod.out_proj
        self.num_heads = mod.num_heads
        self.embed_dim = mod.embed_dim
        self.head_dim = mod.embed_dim // mod.num_heads
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
|
|
16
|
+
from ..seq2seq import RBLNModelForSeq2SeqLMConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """
    RBLN configuration for `RBLNBartModel`.

    Carries the compilation/runtime parameters used when optimizing a BART
    encoder for feature-extraction workloads on RBLN NPUs.
    """
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
    """
    RBLN configuration for `RBLNBartForConditionalGeneration`.

    Carries the compilation/runtime parameters used when optimizing BART for
    conditional text generation on RBLN NPUs.
    """

    # BART's decoder can be compiled against the paged-attention kernels.
    support_paged_attention = True
|