optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,637 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import importlib
|
|
16
|
+
import os
|
|
17
|
+
import shutil
|
|
18
|
+
from abc import ABC
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from tempfile import TemporaryDirectory
|
|
21
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
|
22
|
+
|
|
23
|
+
import rebel
|
|
24
|
+
import torch
|
|
25
|
+
from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
|
|
26
|
+
from transformers.utils.hub import PushToHubMixin
|
|
27
|
+
|
|
28
|
+
from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
|
|
29
|
+
from .utils.hub import pull_compiled_model_from_hub, validate_files
|
|
30
|
+
from .utils.logging import get_logger
|
|
31
|
+
from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
|
|
32
|
+
from .utils.save_utils import maybe_load_preprocessors
|
|
33
|
+
from .utils.submodule import SubModulesMixin
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
38
|
+
|
|
39
|
+
logger = get_logger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PreTrainedModel(ABC):  # noqa: F811
    """Empty abstract placeholder used as a base class in this module.

    The name deliberately shadows the `PreTrainedModel` that is imported from
    `transformers` under `TYPE_CHECKING` only, hence the F811 suppression.
    """
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class RBLNBaseModelConfig(RBLNModelConfig):
    """Configuration class for the base model; adds nothing to `RBLNModelConfig`."""
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
51
|
+
# Model-type key that `__init__` registers with `AutoConfig`.
model_type = "rbln_model"
# Auto class into which `__init__` registers this class, when it exposes `register`.
auto_model_class = AutoModel
config_class = AutoConfig
# presumably the file name of the Hugging Face config inside a model directory — TODO confirm
config_name = "config.json"
# Chooses how `_from_pretrained` loads the config: "transformers" or "diffusers".
hf_library_name = "transformers"
# NOTE(review): not read in the visible part of this file; presumably tells whether
# the model can run with a dtype other than float32 — confirm against subclasses.
_supports_non_fp32 = False
|
|
57
|
+
|
|
58
|
+
def __init__(
    self,
    models: List[rebel.Runtime],
    config: "PretrainedConfig",
    rbln_config: RBLNModelConfig,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    subfolder: str = "",
    rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
    rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
    **kwargs,
):
    """Bind runtimes, configurations and bookkeeping state to the instance.

    Args:
        models: Runtime objects for this model; stored as `self.model`.
        config: Hugging Face configuration; stored as `self.config`.
        rbln_config: RBLN configuration. It must already be frozen.
        model_save_dir: Directory holding the model files. A
            `TemporaryDirectory` is kept referenced on the instance so that it
            is not cleaned up while the model is alive; a `str` is converted to
            `Path`; any other value is stored unchanged.
        subfolder: Stored as `self.subfolder`.
        rbln_compiled_models: Stored as `self.compiled_models`.
        rbln_submodules: Sub-models attached to this model. `None` (the
            default) means no submodules. The sequence is copied so that every
            instance owns its own list.
        **kwargs: Forwarded unchanged to `__post_init__`.

    Raises:
        RuntimeError: If `rbln_config` is not frozen.
    """
    self.model = models
    self.config = config
    self.rbln_config = rbln_config
    if not rbln_config.is_frozen():
        raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
    self.compiled_models = rbln_compiled_models

    # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
    # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
    AutoConfig.register(self.model_type, AutoConfig)
    if hasattr(self.auto_model_class, "register"):
        self.auto_model_class.register(AutoConfig, self.__class__)

    # copied from transformers PreTrainedModel __init__
    if self.can_generate():
        gen_config_dir = model_save_dir.name if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
        self.generation_config = GenerationConfig.from_pretrained(gen_config_dir, trust_remote_code=True)
    else:
        self.generation_config = None

    if self.generation_config is not None:
        self.generation_config.use_cache = True

    self.device = torch.device("cpu")
    self.training = False
    self.dtype = rbln_config.torch_dtype

    # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
    # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
    # would end-up removing the directory containing the underlying RBLN model.
    self._model_save_dir_tempdirectory_instance = None
    if isinstance(model_save_dir, TemporaryDirectory):
        self._model_save_dir_tempdirectory_instance = model_save_dir
        self.model_save_dir = Path(model_save_dir.name)
    elif isinstance(model_save_dir, str):
        self.model_save_dir = Path(model_save_dir)
    else:
        self.model_save_dir = model_save_dir
    self.subfolder = subfolder

    # Copy instead of aliasing: a list shared through a default argument (here or in a
    # caller) must never end up being mutated on behalf of several instances.
    self.rbln_submodules = [] if rbln_submodules is None else list(rbln_submodules)
    self.__post_init__(**kwargs)
|
|
111
|
+
|
|
112
|
+
@classmethod
def _load_compiled_model_dir(
    cls,
    model_id: Union[str, Path],
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
) -> str:
    """Return, as a string, the directory that holds the compiled model files.

    If `model_id` is an existing local directory, `subfolder` is appended to it
    and the `*.rbln` / `rbln_config.json` files found there are handed to
    `validate_files`. Otherwise the location is obtained from
    `pull_compiled_model_from_hub`, which receives all remaining arguments.
    """
    local_root = Path(model_id)

    if not local_root.is_dir():
        # Not a local directory: let the hub helper resolve it.
        hub_location = pull_compiled_model_from_hub(
            model_id=model_id,
            subfolder=subfolder,
            token=token,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
        )
        return str(hub_location)

    compiled_dir = local_root / subfolder
    validate_files(
        list(compiled_dir.glob("*.rbln")),
        list(compiled_dir.glob("rbln_config.json")),
        f"directory {compiled_dir}",
    )
    return str(compiled_dir)
|
|
143
|
+
|
|
144
|
+
@classmethod
def _load_compiled_models(cls, model_path: str, expected_compiled_model_names: List[str]):
    """Load the expected `<name>.rbln` files found directly under `model_path`.

    Any other `*.rbln` file in the directory only triggers a warning.

    Returns:
        A dict mapping each file stem to the `rebel.RBLNCompiledModel` built
        from that file, in the order of `expected_compiled_model_names`.

    Raises:
        FileNotFoundError: If one of the expected files does not exist.
    """
    base_dir = Path(model_path)
    wanted_files = [base_dir / f"{name}.rbln" for name in expected_compiled_model_names]

    stray_names = [found.name for found in base_dir.glob("*.rbln") if found not in wanted_files]
    if stray_names:
        # TODO(jongho): fix after May release. raise error if unexpected compiled models are found
        logger.warning(
            f"Unexpected compiled models found: {stray_names}. "
            f"Please check the model path: {model_path}"
        )

    loaded = {}
    for rbln_file in wanted_files:
        if not rbln_file.exists():
            raise FileNotFoundError(
                f"Expected RBLN compiled model '{rbln_file.name}' not found at '{model_path}'. "
                "Please ensure all models specified in `rbln_config` are present."
            )
        loaded[rbln_file.stem] = rebel.RBLNCompiledModel(rbln_file)
    return loaded
|
|
167
|
+
|
|
168
|
+
@classmethod
def _from_pretrained(
    cls,
    model_id: Union[str, Path],
    config: Optional["PretrainedConfig"] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
    trust_remote_code: bool = False,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    # passed from compile function
    rbln_config: Optional[RBLNModelConfig] = None,
    rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
    rbln_submodules: List["RBLNBaseModel"] = [],
    **kwargs,
) -> "RBLNBaseModel":
    """Build a model instance, loading compiled artifacts unless they are passed in.

    When `rbln_compiled_models` is None the method resolves the compiled-model
    directory, loads the RBLN configuration through `RBLNAutoConfig.load`,
    checks that it was produced for this class, loads submodules if the class
    declares any, freezes the configuration, loads the Hugging Face config if
    `config` is None, loads the compiled models and derives `model_save_dir`.
    It then delegates to `_from_compiled_models`.

    Raises:
        ValueError: If a passed `RBLNModelConfig` names a different model class.
        NameError: If the loaded configuration names a different model class.

    NOTE(review): the block nesting below was reconstructed from a flattened
    listing in which indentation was lost; confirm it against the packaged file.
    """
    if rbln_compiled_models is None:
        model_path_subfolder = cls._load_compiled_model_dir(
            model_id=model_id,
            token=token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            subfolder=subfolder,
            local_files_only=local_files_only,
        )

        # A plain dict is turned into `rbln_`-prefixed keyword arguments so that
        # `RBLNAutoConfig.load` receives them through `kwargs`.
        if isinstance(rbln_config, dict):
            rbln_config_as_kwargs = {f"rbln_{key}": value for key, value in rbln_config.items()}
            kwargs.update(rbln_config_as_kwargs)
            rbln_config = None
        elif isinstance(rbln_config, RBLNModelConfig) and rbln_config.rbln_model_cls_name != cls.__name__:
            raise ValueError(
                f"Cannot use the passed rbln_config. Its model class name ({rbln_config.rbln_model_cls_name}) "
                f"does not match the expected model class name ({cls.__name__})."
            )

        rbln_config, kwargs = RBLNAutoConfig.load(
            model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
        )

        if rbln_config.rbln_model_cls_name != cls.__name__:
            # NOTE(review): the second and third literals are joined without a space,
            # so the message reads "...with <ClassName>.Please use ...".
            raise NameError(
                f"Cannot load the model. The model was originally compiled using "
                f"{rbln_config.rbln_model_cls_name}, but you are trying to load it with {cls.__name__}."
                "Please use the same model class that was used during compilation."
            )

        # Whatever was passed as `rbln_submodules` is replaced on this path.
        if len(cls._rbln_submodules) > 0:
            rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
        else:
            rbln_submodules = []

        rbln_config.freeze()

        if config is None:
            if cls.hf_library_name == "transformers":
                # NOTE(review): `local_files_only` is not forwarded here, unlike the
                # diffusers branch below — confirm whether that is intended.
                config = AutoConfig.from_pretrained(
                    model_path_subfolder,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    revision=revision,
                    token=token,
                    trust_remote_code=trust_remote_code,
                )
            elif cls.hf_library_name == "diffusers":
                # import here to prevent diffusers dependency
                # TODO(jongho): Remove diffusers dependency if use transformers only.
                from diffusers.configuration_utils import ConfigMixin

                class DummyConfigMixin(ConfigMixin):
                    # Just to load config, We need to specify `config_name`
                    config_name = "config.json"

                config = DummyConfigMixin.load_config(
                    model_id,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    revision=revision,
                    token=token,
                    subfolder=subfolder,
                )
                # The loaded mapping is wrapped so the rest of the code sees a `PretrainedConfig`.
                config = PretrainedConfig(**config)

        compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
        rbln_compiled_models = cls._load_compiled_models(model_path_subfolder, compiled_model_names)

        # With a subfolder, the save directory is the parent of the resolved path.
        if subfolder != "":
            model_save_dir = Path(model_path_subfolder).absolute().parent
        else:
            model_save_dir = Path(model_path_subfolder).absolute()

    return cls._from_compiled_models(
        rbln_compiled_models=rbln_compiled_models,
        rbln_config=rbln_config,
        config=config,
        model_save_dir=model_save_dir,
        subfolder=subfolder,
        rbln_submodules=rbln_submodules,
        **kwargs,
    )
|
|
273
|
+
|
|
274
|
+
@classmethod
def _from_compiled_models(
    cls,
    rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
    rbln_config: RBLNModelConfig,
    config: "PretrainedConfig",
    model_save_dir: Union[Path, str],
    subfolder: Union[Path, str],
    rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
    **kwargs,
):
    """Instantiate the class from compiled models, creating runtimes if enabled.

    Args:
        rbln_compiled_models: Compiled models keyed by compiled-model name.
        rbln_config: RBLN configuration; its `compile_cfgs` define the order in
            which the compiled models are handed over, and `create_runtimes`
            decides whether runtimes are created.
        config: Hugging Face configuration passed to the constructor.
        model_save_dir: Model directory; a `str` is converted to `Path`.
        subfolder: Passed to the constructor unchanged.
        rbln_submodules: Sub-models for the instance. `None` (the default)
            yields a fresh empty list, so no list is shared between instances.
        **kwargs: Passed to the constructor unchanged.

    Raises:
        rebel.core.exception.RBLNRuntimeError: If runtime creation fails; the
            original error is chained and the message explains how to skip
            runtime creation.
    """
    if isinstance(model_save_dir, str):
        model_save_dir = Path(model_save_dir)

    if rbln_submodules is None:
        rbln_submodules = []

    # FIXME:: Should we convert it?
    # Dict -> list, ordered like `rbln_config.compile_cfgs`.
    compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
    rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]

    # create runtimes only if `rbln_create_runtimes` is enabled
    try:
        models = (
            cls._create_runtimes(rbln_compiled_models, rbln_config)
            if rbln_config.create_runtimes
            else UnavailableRuntime()
        )

    except rebel.core.exception.RBLNRuntimeError as e:
        error_msg = (
            f"\nFailed to create RBLN runtime: {str(e)}\n\n"
            f"If you only need to compile the model without loading it to NPU, you can use:\n"
            f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
            f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
            f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
            f"Make sure your NPU is properly installed and operational."
        )
        raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

    return cls(
        models,
        config,
        rbln_config,
        model_save_dir=model_save_dir,
        subfolder=subfolder,
        rbln_compiled_models=rbln_compiled_models,
        rbln_submodules=rbln_submodules,
        **kwargs,
    )
|
|
321
|
+
|
|
322
|
+
@classmethod
def _export(cls, model_id: Union[str, Path], **kwargs) -> "RBLNBaseModel":
    # Compile path of `from_pretrained`: loads the PyTorch model and its preprocessors for
    # `model_id`, then delegates to `from_model`.
    #
    # `model_save_dir` is removed from kwargs because it is forwarded explicitly below;
    # `subfolder` is only read, so it stays in kwargs for the calls that follow.
    save_dir = kwargs.pop("model_save_dir", None)
    hub_subfolder = kwargs.get("subfolder", "")

    # Split the remaining kwargs into the RBLN config and whatever is left over.
    rbln_config, kwargs = cls.prepare_rbln_config(**kwargs)

    torch_model: "PreTrainedModel" = cls.get_pytorch_model(
        model_id=model_id,
        rbln_config=rbln_config,
        **kwargs,
    )
    loaded_preprocessors = maybe_load_preprocessors(model_id, subfolder=hub_subfolder)

    return cls.from_model(
        torch_model,
        preprocessors=loaded_preprocessors,
        model_save_dir=save_dir,
        rbln_config=rbln_config,
        **kwargs,
    )
|
|
334
|
+
|
|
335
|
+
@classmethod
def prepare_rbln_config(
    cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
    # Resolves the RBLN config for this model class.
    #
    # Delegates to `initialize_from_kwargs` on the class returned by `get_rbln_config_class()`
    # and returns the resulting config together with the kwargs it hands back.
    resolved_config, remaining_kwargs = cls.get_rbln_config_class().initialize_from_kwargs(
        rbln_config, **kwargs
    )
    return resolved_config, remaining_kwargs
|
|
344
|
+
|
|
345
|
+
@classmethod
def _is_compiled(
    cls,
    model_id: Union[str, Path],
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
) -> bool:
    # Reports whether `model_id` already points at a compiled model.
    #
    # The probe is an attempt to resolve the compiled model directory: success means
    # "compiled", while FileNotFoundError / KeyError mean "not compiled". Any other
    # exception propagates to the caller.
    lookup_args = {
        "model_id": model_id,
        "token": token,
        "revision": revision,
        "force_download": force_download,
        "cache_dir": cache_dir,
        "subfolder": subfolder,
        "local_files_only": local_files_only,
    }
    try:
        cls._load_compiled_model_dir(**lookup_args)
    except (FileNotFoundError, KeyError):
        return False
    return True
|
|
370
|
+
|
|
371
|
+
@classmethod
def from_pretrained(
    cls: Type["RBLNBaseModel"],
    model_id: Union[str, Path],
    export: Optional[bool] = None,
    rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
    **kwargs: Any,
) -> "RBLNBaseModel":
    """
    Load a model for RBLN NPUs, mirroring the HuggingFace `from_pretrained()` entry point.

    Depending on `export`, the model is either compiled from a pre-trained HuggingFace model or
    loaded from files that were already compiled with the RBLN Compiler.

    Args:
        model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
            It can be a HuggingFace model hub id, a local path, or the id of a model already
            compiled with the RBLN Compiler.
        export (Optional[bool]): Whether the model should be compiled.
            If None, the model is compiled only when no compiled model files are found for `model_id`.
        rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
            Either a dictionary or an instance of the model's configuration class
            (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
            See the specific model's configuration class documentation for the available options.
        kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config,
            while the remaining arguments are passed to the HuggingFace library.

    Returns:
        (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
    """

    # The loaders below work with string ids, so normalise Path inputs first.
    if isinstance(model_id, Path):
        model_id = model_id.as_posix()

    if export is None:
        # No explicit request: compile only when no compiled artifacts can be found.
        already_compiled = cls._is_compiled(
            model_id=model_id,
            token=kwargs.get("token"),
            revision=kwargs.get("revision"),
            force_download=kwargs.get("force_download", False),
            cache_dir=kwargs.get("cache_dir"),
            subfolder=kwargs.get("subfolder", ""),
            local_files_only=kwargs.get("local_files_only", False),
        )
        export = not already_compiled

    if export:
        loader = cls._export
    else:
        loader = cls._from_pretrained
    return loader(model_id=model_id, **kwargs, rbln_config=rbln_config)
|
|
413
|
+
|
|
414
|
+
@classmethod
def compile(
    cls,
    model,
    rbln_compile_config: RBLNCompileConfig,
    create_runtimes: bool,
    device: Union[int, List[int]],
    **kwargs,
):
    # Compiles `model` with `rebel.compile_from_torch` using the settings in
    # `rbln_compile_config`; extra kwargs are forwarded to the compiler call.
    #
    # When runtimes are going to be created, the tensor-parallel size / device / npu
    # combination is checked first. A truthy result from `tp_and_devices_are_ok` is treated
    # as the reason a runtime cannot be created and is raised as a ValueError.
    if create_runtimes:
        problem = tp_and_devices_are_ok(
            tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
            device=device,
            npu=rbln_compile_config.npu,
        )
        if problem:
            raise ValueError(problem)

    return rebel.compile_from_torch(
        model,
        input_info=rbln_compile_config.input_info,
        npu=rbln_compile_config.npu,
        tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
        **kwargs,
    )
|
|
440
|
+
|
|
441
|
+
@classmethod
def update_rbln_config(
    cls,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
    model: "PreTrainedModel",
    model_config: "PretrainedConfig",
    rbln_config: RBLNModelConfig,
) -> RBLNModelConfig:
    # Finalises `rbln_config` for `model`.
    #
    # Steps, in order: record the model dtype on the config, reject non-fp32 dtypes for
    # classes that do not support them, let the class-specific `_update_rbln_config` hook
    # fill in the rest, freeze the config, and verify it belongs to this class.
    #
    # Raises:
    #   NotImplementedError: the model dtype is not float32 and the class has
    #       `_supports_non_fp32` unset.
    #   NameError: the config's `rbln_model_cls_name` differs from this class's name.
    rbln_config.torch_dtype = model.dtype
    if not cls._supports_non_fp32:
        if rbln_config.torch_dtype != torch.float32:
            raise NotImplementedError(
                f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
            )

    rbln_config = cls._update_rbln_config(
        preprocessors=preprocessors,
        model=model,
        model_config=model_config,
        rbln_config=rbln_config,
    )
    rbln_config.freeze()

    if rbln_config.rbln_model_cls_name == cls.__name__:
        return rbln_config

    raise NameError(
        f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
        "This is an internal error. Please report it to the developers."
    )
|
|
464
|
+
|
|
465
|
+
@classmethod
def get_hf_class(cls):
    # Resolves, on first use, the HuggingFace class this RBLN class wraps and caches it
    # on the class itself (a value inherited from a parent class is not reused).
    #
    # The lookup name is the class name without its first four characters, i.e. without
    # the 'RBLN' prefix (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM). It is fetched
    # from the module named by `cls.hf_library_name`.
    #
    # Returns:
    #   type: The original HuggingFace model class, or None if the library has no
    #   attribute with that name.
    if cls.__dict__.get("_hf_class") is None:
        hf_module = importlib.import_module(cls.hf_library_name)
        cls._hf_class = getattr(hf_module, cls.__name__[4:], None)
    return cls._hf_class
|
|
479
|
+
|
|
480
|
+
@classmethod
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
    # Resolves, on first use, the RBLN config class paired with this model class and
    # caches it on the class itself. The config class is looked up by name:
    # "<model class name>Config".
    if cls.__dict__.get("_rbln_config_class") is None:
        # The bare name below is the module-level helper, not this classmethod.
        cls._rbln_config_class = get_rbln_config_class(f"{cls.__name__}Config")
    return cls._rbln_config_class
|
|
487
|
+
|
|
488
|
+
def can_generate(self):
    # The base model always reports that it cannot generate.
    return False
|
|
490
|
+
|
|
491
|
+
def to(self, *args, **kwargs):
    # No-op counterpart of `torch.nn.Module.to`: every argument is ignored and the
    # same object is returned.
    return self
|
|
493
|
+
|
|
494
|
+
def parameters(self):
    # Compatibility stand-in for `torch.nn.Module.parameters()`.
    #
    # It exists for callers that do `next(model.parameters())` to discover the device
    # or dtype: exactly one dummy CPU tensor carrying the model dtype is produced.
    #
    # Warning:
    #   The tensor is NOT one of the parameters used by the RBLN runtime, so code that
    #   walks over all model parameters will not work as expected.
    placeholder = torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))
    yield placeholder
|
|
505
|
+
|
|
506
|
+
def __call__(self, *args, **kwargs):
    # Calling the model is the same as calling `forward` with the same arguments.
    return self.forward(*args, **kwargs)
|
|
508
|
+
|
|
509
|
+
def __repr__(self):
    # Renders a multi-line summary: the class name, how many runtimes (and submodules, if
    # any) the model holds, the repr of every runtime, and then the repr of every submodule
    # under a header that carries its index and registered name.
    has_submodules = len(self.rbln_submodules) > 0

    summary = f"- Total {len(self.model)} Runtimes"
    if has_submodules:
        summary += f" and {len(self.rbln_submodules)} Submodules"

    # Collect the pieces and join once instead of growing a string with `+=`.
    pieces = [
        f"<{self.__class__.__name__}>\n",
        summary + "\n",
        "[Runtimes]\n",
        "\n".join([repr(model) for model in self.model]),
        "\n",
    ]

    if has_submodules:
        for i, submodule in enumerate(self.rbln_submodules):
            # The display name is read from the `_rbln_submodules` entries, indexed in
            # step with the instantiated submodules.
            pieces.append(f"[Submodules {i} : {self._rbln_submodules[i]['name']}]\n")
            pieces.append(repr(submodule) + "\n")

    return "".join(pieces)
|
|
524
|
+
|
|
525
|
+
def __post_init__(self, **kwargs):
    # Intentionally empty in the base class: accepts any keyword arguments and does nothing.
    return None
|
|
527
|
+
|
|
528
|
+
def save_pretrained(
    self,
    save_directory: Union[str, Path],
    push_to_hub: bool = False,
    **kwargs,
):
    """
    Saves a model and its configuration file to a directory, so that it can be re-loaded using the
    [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.

    The files are first copied into a sibling `<save_directory>.tmp` directory and only then moved
    (or merged, if `save_directory` already exists) into place.

    Args:
        save_directory (Union[str, Path]):
            Directory where to save the model file.
        push_to_hub (bool):
            Whether or not to push your model to the HuggingFace model hub after saving it.
        kwargs:
            Only used when `push_to_hub` is True: `repo_id` is required, and the remaining
            entries are forwarded to `push_to_hub`.

    Raises:
        FileNotFoundError: The directory holding this model's files does not exist.
        FileExistsError: `save_directory` resolves to the directory the model already lives in.
        ValueError: `push_to_hub` is True but no `repo_id` was given.
    """

    def _merge_dir(src_root: str, dst_root: str):
        # Recursively move every entry of `src_root` into `dst_root`, replacing entries that
        # are already there. Destination entries with no counterpart in the source are kept.
        for name in os.listdir(src_root):
            src_item = os.path.join(src_root, name)
            dst_item = os.path.join(dst_root, name)

            if os.path.islink(src_item) or os.path.isfile(src_item):
                os.makedirs(os.path.dirname(dst_item), exist_ok=True)
                # A real directory in the way cannot be replaced by a file; remove it first.
                if os.path.isdir(dst_item) and not os.path.islink(dst_item):
                    shutil.rmtree(dst_item)
                os.replace(src_item, dst_item)
            elif os.path.isdir(src_item):
                # A file or link in the way of a directory is removed before descending.
                if os.path.islink(dst_item) or os.path.isfile(dst_item):
                    os.remove(dst_item)
                os.makedirs(dst_item, exist_ok=True)
                _merge_dir(src_item, dst_item)
            else:
                # Fallback for special file types
                os.replace(src_item, dst_item)

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    # Normalize paths to handle relative paths and symlinks
    real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
    save_directory_path = Path(save_directory).resolve()

    # `isdir` is already False for a missing path, so no separate existence check is needed.
    if not os.path.isdir(real_save_dir):
        raise FileNotFoundError(
            f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
            f"Cannot save to the specified destination '{save_directory}'. "
            f"Please ensure the model directory exists and you have the necessary permissions to access it."
        )

    # The config is written into the source directory so that it is part of the copy below.
    if isinstance(self.config, PretrainedConfig):
        self.config.save_pretrained(real_save_dir)

    if save_directory_path == real_save_dir:
        raise FileExistsError(
            f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
        )

    # Create a temporary directory with normalized path
    tmp_dir = str(save_directory_path) + ".tmp"
    try:
        # Remove temporary directory if it exists from a previous failed attempt
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)

        # First copy everything to a temporary directory
        shutil.copytree(real_save_dir, tmp_dir)

        # If everything succeeded, move files to target directory
        if os.path.exists(save_directory_path):
            # Merge files from tmp_dir into existing directory
            _merge_dir(tmp_dir, str(save_directory_path))

            # Remove the temporary directory tree after merge
            shutil.rmtree(tmp_dir)
        else:
            # If target doesn't exist, just rename tmp_dir to target
            os.rename(tmp_dir, save_directory_path)

    except Exception:
        # Best-effort cleanup: a failure while removing the temporary directory must not
        # replace the error that actually caused the save to fail.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    if push_to_hub:
        repo_id = kwargs.pop("repo_id", None)
        if repo_id is None:
            raise ValueError("`repo_id` must be provided to push the model to the HuggingFace model hub.")
        return super().push_to_hub(repo_id=repo_id, **kwargs)
|
|
619
|
+
|
|
620
|
+
@staticmethod
def _raise_missing_compiled_file_error(missing_files: List[str]):
    # Always raises a KeyError that names the compiled model files (`<name>.rbln`) that
    # could not be found. The wording differs for exactly one missing file versus any
    # other count; the recompilation advice at the end is shared.
    if len(missing_files) == 1:
        headline = f"The rbln model folder is missing the required '{missing_files[0]}.rbln' file. "
    else:
        quoted_names = ", ".join(f"'{name}.rbln'" for name in missing_files)
        headline = (
            "The rbln model folder is missing required files. "
            f"Ensure that {quoted_names} files are present in the folder. "
        )

    advice = (
        "These files are necessary for loading the rbln model. "
        "If these files are missing, please recompile the model using the latest optimum-rbln "
        "and ensure the compilation completes successfully."
    )
    raise KeyError(headline + advice)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .attn import *
|
|
16
|
+
from .flash_attn import *
|
|
17
|
+
from .kv_cache_update import *
|
|
18
|
+
from .linear import linear
|
|
19
|
+
from .sliding_window_attn import *
|