optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import copy
|
|
16
|
+
import importlib
|
|
17
|
+
from os import PathLike
|
|
18
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from ..configuration_utils import ContextRblnConfig, RBLNModelConfig, get_rbln_config_class
|
|
23
|
+
from ..modeling import RBLNModel
|
|
24
|
+
from ..utils.decorator_utils import remove_compile_time_kwargs
|
|
25
|
+
from ..utils.logging import get_logger
|
|
26
|
+
from ..utils.model_utils import get_rbln_model_cls
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Module-level logger created with the package's own `get_logger` helper, named after this module.
logger = get_logger(__name__)
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RBLNDiffusionMixinConfig(RBLNModelConfig):
    """
    Configuration class for RBLN diffusion pipelines.

    Adds no fields or behavior on top of `RBLNModelConfig`. In this module it is the type
    used for pipeline-level `rbln_config` objects, whose attributes hold the per-submodule
    (or per-connected-pipeline) configurations.
    """
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class RBLNDiffusionMixin:
|
|
44
|
+
"""
|
|
45
|
+
RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
|
|
46
|
+
This mixin class serves as a base for implementing RBLN-compatible Stable Diffusion pipelines. It contains shared logic for
|
|
47
|
+
handling the core components of Stable Diffusion.
|
|
48
|
+
|
|
49
|
+
To use this mixin:
|
|
50
|
+
|
|
51
|
+
1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
|
|
52
|
+
2. Define the required `_submodules` and `_optional_submodules` class variables listing the components to be compiled.
|
|
53
|
+
|
|
54
|
+
Example:
|
|
55
|
+
```python
|
|
56
|
+
class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
|
|
57
|
+
_submodules = ["text_encoder", "unet", "vae"]
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
Class Variables:
|
|
61
|
+
- `_submodules`: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
|
|
62
|
+
- `_optional_submodules`: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])
|
|
63
|
+
|
|
64
|
+
Methods:
|
|
65
|
+
from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
|
|
66
|
+
|
|
67
|
+
Notes:
|
|
68
|
+
- When `export=True`, all compatible submodules will be compiled for NPU inference
|
|
69
|
+
- The compilation config can be customized per submodule by including submodule names
|
|
70
|
+
as keys in rbln_config
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
_connected_classes = {}
|
|
74
|
+
_submodules = []
|
|
75
|
+
_optional_submodules = []
|
|
76
|
+
_prefix = {}
|
|
77
|
+
|
|
78
|
+
@staticmethod
|
|
79
|
+
def _maybe_apply_and_fuse_lora(
|
|
80
|
+
model: torch.nn.Module,
|
|
81
|
+
lora_ids: Optional[Union[str, List[str]]] = None,
|
|
82
|
+
lora_weights_names: Optional[Union[str, List[str]]] = None,
|
|
83
|
+
lora_scales: Optional[Union[float, List[float]]] = None,
|
|
84
|
+
) -> torch.nn.Module:
|
|
85
|
+
lora_ids = [lora_ids] if isinstance(lora_ids, str) else lora_ids
|
|
86
|
+
lora_weights_names = [lora_weights_names] if isinstance(lora_weights_names, str) else lora_weights_names
|
|
87
|
+
lora_scales = [lora_scales] if isinstance(lora_scales, float) else lora_scales
|
|
88
|
+
|
|
89
|
+
# adapt lora weight into pipeline before compilation
|
|
90
|
+
if lora_ids and lora_weights_names:
|
|
91
|
+
if len(lora_ids) == 1:
|
|
92
|
+
if len(lora_ids) != len(lora_weights_names):
|
|
93
|
+
raise ValueError(
|
|
94
|
+
f"You must define the same number of lora ids ({len(lora_ids)} and lora weights ({len(lora_weights_names)}))"
|
|
95
|
+
)
|
|
96
|
+
else:
|
|
97
|
+
model.load_lora_weights(lora_ids[0], weight_name=lora_weights_names[0])
|
|
98
|
+
model.fuse_lora(lora_scale=lora_scales[0] if lora_scales else 1.0)
|
|
99
|
+
elif len(lora_ids) > 1:
|
|
100
|
+
if not len(lora_ids) == len(lora_weights_names):
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f"If you fuse {len(lora_ids)} lora models, but you must define the same number for lora weights and adapters."
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
adapter_names = [f"adapter_{i}" for i in range(len(lora_ids))]
|
|
106
|
+
|
|
107
|
+
for lora_id, lora_weight, adapter_name in zip(lora_ids, lora_weights_names, adapter_names):
|
|
108
|
+
model.load_lora_weights(lora_id, weight_name=lora_weight, adapter_name=adapter_name)
|
|
109
|
+
|
|
110
|
+
if lora_scales:
|
|
111
|
+
model.set_adapters(adapter_names, adapter_weights=lora_scales)
|
|
112
|
+
|
|
113
|
+
model.fuse_lora()
|
|
114
|
+
return model
|
|
115
|
+
|
|
116
|
+
@classmethod
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
    """
    Return the RBLN config class paired with this pipeline class.

    On first use the class named `<ClassName>Config` is resolved through the module-level
    `get_rbln_config_class` helper and stored on `cls` as `_rbln_config_class`. Only a
    non-None value stored directly in `cls.__dict__` counts as cached; an inherited value
    does not, so each subclass resolves its own config class.
    """
    if "_rbln_config_class" in cls.__dict__ and cls._rbln_config_class is not None:
        return cls._rbln_config_class

    # Note: this name resolves to the imported module-level helper, not to this classmethod.
    config_class_name = f"{cls.__name__}Config"
    cls._rbln_config_class = get_rbln_config_class(config_class_name)
    return cls._rbln_config_class
|
|
123
|
+
|
|
124
|
+
@classmethod
|
|
125
|
+
def get_hf_class(cls):
|
|
126
|
+
if "_hf_class" not in cls.__dict__ or cls._hf_class is None:
|
|
127
|
+
hf_cls_name = cls.__name__[4:]
|
|
128
|
+
library = importlib.import_module("diffusers")
|
|
129
|
+
cls._hf_class = getattr(library, hf_cls_name, None)
|
|
130
|
+
return cls._hf_class
|
|
131
|
+
|
|
132
|
+
@classmethod
def from_pretrained(
    cls,
    model_id: str,
    *,
    export: Optional[bool] = None,
    model_save_dir: Optional[PathLike] = None,
    rbln_config: Optional[Dict[str, Any]] = None,
    lora_ids: Optional[Union[str, List[str]]] = None,
    lora_weights_names: Optional[Union[str, List[str]]] = None,
    lora_scales: Optional[Union[float, List[float]]] = None,
    **kwargs: Any,
) -> "RBLNDiffusionMixin":
    """
    Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.

    This method has two distinct operating modes:
    - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
    - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation

    It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.

    Args:
        model_id (`str`):
            The model ID or path to the pretrained model to load. Can be either:

            - A model ID from the HuggingFace Hub
            - A local path to a saved model directory
        export:
            If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
            If False, loads an already compiled RBLN model from `model_id` without recompilation.
            If None (the default), compilation is performed when any submodule listed in
            `_submodules` is not already compiled under `model_id`.
        model_save_dir:
            Directory to save the compiled model artifacts. Only used when `export=True`.
            If not provided and `export=True`, a temporary directory is used.
        rbln_config:
            Configuration options for RBLN compilation. Can include settings for specific submodules
            such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
            pipeline being compiled. When omitted, a fresh empty dict is used for this call.
        lora_ids:
            LoRA adapter ID(s) to load and apply before compilation. LoRA weights are fused
            into the model weights during compilation. Only used when `export=True`.
        lora_weights_names:
            Names of specific LoRA weight files to load, corresponding to lora_ids. Only used when `export=True`.
        lora_scales:
            Scaling factor(s) to apply to the LoRA adapter(s). Only used when `export=True`.
        kwargs:
            Additional arguments to pass to the underlying diffusion pipeline constructor or the
            RBLN compilation process. These may include parameters specific to individual submodules
            or the particular diffusion pipeline being used.

    Returns:
        A compiled or loaded diffusion pipeline that can be used for inference on RBLN NPU.
        The returned object is an instance of the class that called this method, inheriting from RBLNDiffusionMixin.

    Raises:
        ValueError: If `export` resolves to False and `model_index.json` lists a non-`optimum.rbln`
            module for a submodule the caller did not pass.
        AssertionError: If `export` resolves to False and an uncompiled `torch.nn.Module` is passed
            for a submodule.
    """
    # Use a fresh dict per call instead of a shared `{}` default argument. The config helper
    # below may update the dict it receives — NOTE(review): confirm in
    # RBLNModelConfig.initialize_from_kwargs. If it does, a shared default would leak
    # settings from one call into the next.
    if rbln_config is None:
        rbln_config = {}
    rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)

    if export is None:
        # Compile when at least one required submodule has no compiled artifacts under `model_id`.
        export = any(
            not RBLNModel._is_compiled(
                model_id,
                token=kwargs.get("token"),
                revision=kwargs.get("revision"),
                force_download=kwargs.get("force_download", False),
                cache_dir=kwargs.get("cache_dir"),
                subfolder=submodule_name,
                local_files_only=kwargs.get("local_files_only", False),
            )
            for submodule_name in cls._submodules
        )

    if export:
        # Keep any already-compiled RBLN submodules the user passed; they are reused instead of recompiled.
        passed_submodules = {
            name: kwargs.pop(name)
            for name in cls._submodules + cls._optional_submodules
            if isinstance(kwargs.get(name), RBLNModel)
        }
    else:
        # Load every compiled submodule that was not passed. Reject uncompiled torch modules.
        model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
        for submodule_name in cls._submodules + cls._optional_submodules:
            passed_submodule = kwargs.get(submodule_name, None)

            if passed_submodule is None:
                module_name, class_name = model_index_config[submodule_name]
                if module_name != "optimum.rbln":
                    raise ValueError(
                        f"Invalid module_name '{module_name}' found in model_index.json for "
                        f"submodule '{submodule_name}'. "
                        "Expected 'optimum.rbln'. Please check the model_index.json configuration. "
                        "If you want to compile, set `export=True`."
                    )

                submodule_cls = get_rbln_model_cls(class_name)
                submodule_config = getattr(rbln_config, submodule_name)
                submodule = submodule_cls.from_pretrained(
                    model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
                )
            elif passed_submodule.__class__.__name__.startswith("RBLN"):
                submodule = passed_submodule
            elif isinstance(passed_submodule, torch.nn.Module):
                raise AssertionError(
                    f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
                )
            else:
                # Any other user-provided object is forwarded unchanged. Previously `submodule`
                # was left unassigned here. It then raised UnboundLocalError, or silently reused
                # the submodule from the previous loop iteration.
                submodule = passed_submodule

            kwargs[submodule_name] = submodule

    with ContextRblnConfig(
        device=rbln_config.device,
        device_map=rbln_config.device_map,
        create_runtimes=rbln_config.create_runtimes,
        activate_profiler=rbln_config.activate_profiler,
        timeout=rbln_config.timeout,
    ):
        model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)

    if not export:
        return model

    # LoRA weights are fused into the PyTorch weights before compilation.
    model = cls._maybe_apply_and_fuse_lora(
        model,
        lora_ids=lora_ids,
        lora_weights_names=lora_weights_names,
        lora_scales=lora_scales,
    )

    if cls._load_connected_pipes:
        compiled_submodules = cls._compile_pipelines(model, passed_submodules, model_save_dir, rbln_config)
    else:
        compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
    return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)
|
|
267
|
+
|
|
268
|
+
@classmethod
def _compile_pipelines(
    cls,
    model: torch.nn.Module,
    passed_submodules: Dict[str, RBLNModel],
    model_save_dir: Optional[PathLike],
    rbln_config: "RBLNDiffusionMixinConfig",
) -> Dict[str, RBLNModel]:
    """
    Compile the submodules of every connected sub-pipeline of a combined pipeline.

    For each `(attribute name, pipeline class)` pair in `cls._connected_classes`, the
    sub-pipeline stored on `model` under that attribute is compiled by the connected class's
    `_compile_submodules`. That call uses the section of `rbln_config` with the same attribute
    name. Submodule names are namespaced with the prefix registered in `cls._prefix` (an empty
    string if none is registered). The prefix is applied both when looking up user-supplied
    `passed_submodules` and when keying the result.

    Returns:
        A flat mapping from prefixed submodule name to compiled submodule.
    """
    result = {}
    for pipe_attr, pipe_cls in cls._connected_classes.items():
        prefix = cls._prefix.get(pipe_attr, "")
        # One entry per submodule of the connected pipeline; None when the user passed nothing for it.
        user_submodules = {name: passed_submodules.get(prefix + name, None) for name in pipe_cls._submodules}
        sub_pipe = getattr(model, pipe_attr)
        sub_config = getattr(rbln_config, pipe_attr)
        compiled = pipe_cls._compile_submodules(sub_pipe, user_submodules, model_save_dir, sub_config, prefix)
        result.update({prefix + name: module for name, module in compiled.items()})
    return result
|
|
293
|
+
|
|
294
|
+
@classmethod
def _compile_submodules(
    cls,
    model: torch.nn.Module,
    passed_submodules: Dict[str, RBLNModel],
    model_save_dir: Optional[PathLike],
    rbln_config: RBLNDiffusionMixinConfig,
    prefix: Optional[str] = "",
) -> Dict[str, RBLNModel]:
    """Compile, or reuse, every submodule listed in ``cls._submodules``.

    For each submodule name, the module is taken from ``passed_submodules``
    when an entry exists there and is not ``None``. Otherwise it is read from
    the attribute of ``model`` with the same name. It is then handled as follows:

    * already an ``RBLNModel``: reused unchanged;
    * ``controlnet`` exposing ``nets`` (MultiControlNet): every net is compiled
      by ``cls._compile_multicontrolnet``;
    * any other ``torch.nn.Module``: compiled with ``from_model`` of the RBLN
      class named by the submodule's RBLN config (``rbln_model_cls``), into the
      subfolder ``prefix + submodule_name``.

    Args:
        model: Source pipeline the submodules are read from.
        passed_submodules: User-supplied submodules keyed by name; they take
            precedence over the ones found on ``model``.
        model_save_dir: Directory the compiled artifacts are saved under.
        rbln_config: Pipeline-level RBLN config holding one sub-config per
            submodule name.
        prefix: Prepended to subfolder names (used for connected pipes).
            ``None`` is treated as ``""``.

    Returns:
        Mapping from (unprefixed) submodule name to compiled RBLN submodule.

    Raises:
        ValueError: If a submodule's RBLN config is missing, the submodule
            itself is missing, or the submodule has an unsupported type.
    """
    # The parameter is typed Optional[str]; avoid `None + str` / "Nonecontrolnet".
    prefix = prefix or ""
    compiled_submodules = {}

    for submodule_name in cls._submodules:
        # Explicit None check instead of `a or b`: `or` would also fall back to the
        # model's module whenever a passed object is merely falsy.
        submodule = passed_submodules.get(submodule_name)
        if submodule is None:
            submodule = getattr(model, submodule_name, None)

        if getattr(rbln_config, submodule_name, None) is None:
            raise ValueError(f"RBLN config for submodule {submodule_name} is not provided.")

        submodule_rbln_cls: Type[RBLNModel] = getattr(rbln_config, submodule_name).rbln_model_cls
        # The pipeline config is rebound to the returned value, so the per-submodule
        # config is re-read from it below rather than cached before this call.
        rbln_config = submodule_rbln_cls.update_rbln_config_using_pipe(model, rbln_config, submodule_name)

        if submodule is None:
            raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
        elif isinstance(submodule, RBLNModel):
            # Already compiled: reuse as-is.
            pass
        elif submodule_name == "controlnet" and hasattr(submodule, "nets"):
            # MultiControlNet: compile each contained ControlNet separately.
            submodule = cls._compile_multicontrolnet(
                controlnets=submodule,
                model_save_dir=model_save_dir,
                controlnet_rbln_config=getattr(rbln_config, submodule_name),
                prefix=prefix,
            )
        elif isinstance(submodule, torch.nn.Module):
            subfolder = prefix + submodule_name
            submodule = submodule_rbln_cls.from_model(
                model=submodule,
                subfolder=subfolder,
                model_save_dir=model_save_dir,
                rbln_config=getattr(rbln_config, submodule_name),
            )
        else:
            raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")

        compiled_submodules[submodule_name] = submodule
    return compiled_submodules
|
|
338
|
+
|
|
339
|
+
@classmethod
def _compile_multicontrolnet(
    cls,
    controlnets: "MultiControlNetModel",
    model_save_dir: Optional[PathLike],
    controlnet_rbln_config: RBLNModelConfig,
    prefix: Optional[str] = "",
):
    """Compile each ControlNet of a MultiControlNet into its own RBLN model.

    The first net is saved under the subfolder ``{prefix}controlnet``; the
    net at index ``i >= 1`` is saved under ``{prefix}controlnet_{i}``. Every net
    receives its own deep copy of ``controlnet_rbln_config``, so compiling one
    net cannot mutate the config used by another.

    Returns:
        An ``RBLNMultiControlNetModel`` wrapping the compiled nets in their
        original order.
    """
    from .models.controlnet import RBLNControlNetModel
    from .pipelines.controlnet import RBLNMultiControlNetModel

    def _subfolder_for(index):
        # Index 0 keeps the plain name; later nets get a numeric suffix.
        suffix = "" if index == 0 else f"_{index}"
        return f"{prefix}controlnet{suffix}"

    compiled_nets = [
        RBLNControlNetModel.from_model(
            model=net,
            subfolder=_subfolder_for(index),
            model_save_dir=model_save_dir,
            rbln_config=copy.deepcopy(controlnet_rbln_config),
        )
        for index, net in enumerate(controlnets.nets)
    ]
    return RBLNMultiControlNetModel(compiled_nets)
|
|
363
|
+
|
|
364
|
+
@classmethod
def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
    """Attach compiled RBLN submodules to ``model`` and finalize its configuration.

    Steps:
      1. If ``model_save_dir`` is not ``None``, delete the original PyTorch
         submodules (also those of connected sub-pipelines when
         ``cls._load_connected_pipes`` is set) and call
         ``model.save_pretrained(model_save_dir)``.
      2. Set each compiled submodule from ``submodules`` on ``model``, and on
         the connected sub-pipelines (looked up by prefixed name).
      3. Register ``("optimum.rbln", <class name>)`` per submodule in the model
         config. An optional submodule whose attribute is ``None`` is
         registered as ``(None, None)`` instead.
      4. If ``model_save_dir`` is not ``None``, rewrite the saved config.
      5. Gather ``compiled_models`` of all ``cls._submodules`` into
         ``model.compiled_models``.

    Args:
        model: The pipeline being finalized; modified in place.
        submodules: Mapping from (prefixed) submodule name to compiled RBLN submodule.
        model_save_dir: Directory to save to, or ``None`` to skip saving.
        rbln_config: Not used by this method.

    Returns:
        The same ``model`` object.
    """
    if model_save_dir is not None:
        # Remove the original pytorch modules so save_pretrained does not save them.
        for submodule_name in cls._submodules:
            delattr(model, submodule_name)

        if cls._load_connected_pipes:
            for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
                for submodule_name in connected_pipe_cls._submodules:
                    delattr(getattr(model, connected_pipe_name), submodule_name)

        # Direct calling of `save_pretrained` causes config.unet = (None, None).
        # So the config must be saved again, later.
        model.save_pretrained(model_save_dir)
        # FIXME: Here, model touches its submodules such as model.unet,
        # causing warning messages.

    update_dict = {}
    for submodule_name in cls._submodules + cls._optional_submodules:
        if submodule_name in submodules:
            # Replace the submodule with its compiled counterpart.
            submodule = submodules[submodule_name]
            setattr(model, submodule_name, submodule)
        else:
            # Modules not compiled here (e.g. optional ones) are assumed to be
            # already compiled and registered as attributes of the model.
            submodule = getattr(model, submodule_name)

        if submodule is None:
            # Absent optional component: "NoneType" is not an RBLN class name, so
            # record it as (None, None), the way diffusers records missing components.
            update_dict[submodule_name] = (None, None)
        else:
            update_dict[submodule_name] = ("optimum.rbln", submodule.__class__.__name__)

    if cls._load_connected_pipes:
        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
            prefix = cls._prefix.get(connected_pipe_name, "")
            for submodule_name in connected_pipe_cls._submodules:
                setattr(getattr(model, connected_pipe_name), submodule_name, submodules[prefix + submodule_name])

    # Update config to be able to load from the model directory, e.g.
    # update_dict = {
    #     "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
    #     "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
    #     "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
    # }
    model.register_to_config(**update_dict)

    # Same guard as the save_pretrained call above, so both writes happen together.
    if model_save_dir is not None:
        # Overwrite to replace the incorrect config written by save_pretrained.
        model.save_config(model_save_dir)

    # Keep compiled_model objs for further analysis. -> TODO: remove soon...
    model.compiled_models = []
    for name in cls._submodules:
        submodule = getattr(model, name)
        model.compiled_models.extend(submodule.compiled_models)

    return model
|
|
421
|
+
|
|
422
|
+
def get_compiled_image_size(self):
    """Return the image size the VAE was compiled for, if available.

    Returns:
        ``self.vae.image_size`` when the pipeline has a ``vae`` attribute that
        itself exposes ``image_size``; otherwise ``None``.
    """
    vae = getattr(self, "vae", None)
    return getattr(vae, "image_size", None)
|
|
428
|
+
|
|
429
|
+
def handle_additional_kwargs(self, **kwargs):
    """Hook for adjusting call-time kwargs that are fixed at compile time.

    The base implementation forwards every keyword argument unchanged. A
    pipeline whose extra inference parameters are determined by one of its
    compiled modules should override this method, for example::

        if hasattr(self, "movq"):
            compiled_image_size = self.movq.image_size
            kwargs["height"] = compiled_image_size[0]
            kwargs["width"] = compiled_image_size[1]

        compiled_num_frames = self.unet.rbln_config.num_frames
        if compiled_num_frames is not None:
            kwargs["num_frames"] = compiled_num_frames
        return kwargs

    Returns:
        dict: The keyword arguments to forward to the pipeline call.
    """
    forwarded_kwargs = kwargs
    return forwarded_kwargs
|
|
447
|
+
|
|
448
|
+
@remove_compile_time_kwargs
def __call__(self, *args, **kwargs):
    """Run the parent pipeline's ``__call__`` with hook-adjusted keyword arguments.

    The call is wrapped by ``remove_compile_time_kwargs``; its exact filtering
    is defined elsewhere. The received kwargs are passed through
    ``self.handle_additional_kwargs`` before being forwarded, together with
    the positional arguments, to ``super().__call__``.
    """
    call_kwargs = self.handle_additional_kwargs(**kwargs)
    return super().__call__(*args, **call_kwargs)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
from transformers.utils import _LazyModule
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Maps each subpackage of `models` to the public RBLN model classes it provides.
# It is handed to transformers' `_LazyModule` below, which imports a subpackage
# only when one of its listed names is first accessed.
_import_structure = {
    "autoencoders": [
        "RBLNAutoencoderKL",
        "RBLNAutoencoderKLCosmos",
        "RBLNVQModel",
        "RBLNAutoencoderKLTemporalDecoder",
    ],
    "unets": [
        "RBLNUNet2DConditionModel",
        "RBLNUNetSpatioTemporalConditionModel",
    ],
    "controlnet": ["RBLNControlNetModel"],
    "transformers": [
        "RBLNPriorTransformer",
        "RBLNCosmosTransformer3DModel",
        "RBLNSD3Transformer2DModel",
    ],
}

# Static type checkers see eager imports of the same names; keep this branch in
# sync with `_import_structure` above when adding or removing a class.
if TYPE_CHECKING:
    from .autoencoders import (
        RBLNAutoencoderKL,
        RBLNAutoencoderKLCosmos,
        RBLNAutoencoderKLTemporalDecoder,
        RBLNVQModel,
    )
    from .controlnet import RBLNControlNetModel
    from .transformers import (
        RBLNCosmosTransformer3DModel,
        RBLNPriorTransformer,
        RBLNSD3Transformer2DModel,
    )
    from .unets import (
        RBLNUNet2DConditionModel,
        RBLNUNetSpatioTemporalConditionModel,
    )
else:
    import sys

    # At runtime, replace this module object in sys.modules with a lazy proxy so
    # the heavy model submodules are imported on first attribute access only.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .autoencoder_kl import RBLNAutoencoderKL
|
|
16
|
+
from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
|
|
17
|
+
from .autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoder
|
|
18
|
+
from .vq_model import RBLNVQModel
|