optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
from transformers.utils import _LazyModule
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Lazy-import table consumed by ``_LazyModule`` below: each key is a submodule of
# this package and each value lists the public names that submodule provides.
_import_structure = dict(
    auto_pipeline=[
        "RBLNAutoPipelineForImage2Image",
        "RBLNAutoPipelineForInpainting",
        "RBLNAutoPipelineForText2Image",
    ],
    controlnet=[
        "RBLNMultiControlNetModel",
        "RBLNStableDiffusionControlNetImg2ImgPipeline",
        "RBLNStableDiffusionControlNetPipeline",
        "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
        "RBLNStableDiffusionXLControlNetPipeline",
    ],
    cosmos=[
        "RBLNCosmosTextToWorldPipeline",
        "RBLNCosmosVideoToWorldPipeline",
        "RBLNCosmosSafetyChecker",
    ],
    kandinsky2_2=[
        "RBLNKandinskyV22CombinedPipeline",
        "RBLNKandinskyV22Img2ImgCombinedPipeline",
        "RBLNKandinskyV22InpaintCombinedPipeline",
        "RBLNKandinskyV22InpaintPipeline",
        "RBLNKandinskyV22Img2ImgPipeline",
        "RBLNKandinskyV22PriorPipeline",
        "RBLNKandinskyV22Pipeline",
    ],
    stable_diffusion=[
        "RBLNStableDiffusionImg2ImgPipeline",
        "RBLNStableDiffusionPipeline",
        "RBLNStableDiffusionInpaintPipeline",
    ],
    stable_diffusion_xl=[
        "RBLNStableDiffusionXLImg2ImgPipeline",
        "RBLNStableDiffusionXLPipeline",
        "RBLNStableDiffusionXLInpaintPipeline",
    ],
    stable_diffusion_3=[
        "RBLNStableDiffusion3Pipeline",
        "RBLNStableDiffusion3Img2ImgPipeline",
        "RBLNStableDiffusion3InpaintPipeline",
    ],
    stable_video_diffusion=[
        "RBLNStableVideoDiffusionPipeline",
    ],
)

if TYPE_CHECKING:
    # Eager imports so static type checkers and IDEs can resolve every exported name.
    from .auto_pipeline import (
        RBLNAutoPipelineForImage2Image,
        RBLNAutoPipelineForInpainting,
        RBLNAutoPipelineForText2Image,
    )
    from .controlnet import (
        RBLNMultiControlNetModel,
        RBLNStableDiffusionControlNetImg2ImgPipeline,
        RBLNStableDiffusionControlNetPipeline,
        RBLNStableDiffusionXLControlNetImg2ImgPipeline,
        RBLNStableDiffusionXLControlNetPipeline,
    )
    from .cosmos import (
        RBLNCosmosSafetyChecker,
        RBLNCosmosTextToWorldPipeline,
        RBLNCosmosVideoToWorldPipeline,
    )
    from .kandinsky2_2 import (
        RBLNKandinskyV22CombinedPipeline,
        RBLNKandinskyV22Img2ImgCombinedPipeline,
        RBLNKandinskyV22Img2ImgPipeline,
        RBLNKandinskyV22InpaintCombinedPipeline,
        RBLNKandinskyV22InpaintPipeline,
        RBLNKandinskyV22Pipeline,
        RBLNKandinskyV22PriorPipeline,
    )
    from .stable_diffusion import (
        RBLNStableDiffusionImg2ImgPipeline,
        RBLNStableDiffusionInpaintPipeline,
        RBLNStableDiffusionPipeline,
    )
    from .stable_diffusion_3 import (
        RBLNStableDiffusion3Img2ImgPipeline,
        RBLNStableDiffusion3InpaintPipeline,
        RBLNStableDiffusion3Pipeline,
    )
    from .stable_diffusion_xl import (
        RBLNStableDiffusionXLImg2ImgPipeline,
        RBLNStableDiffusionXLInpaintPipeline,
        RBLNStableDiffusionXLPipeline,
    )
    from .stable_video_diffusion import RBLNStableVideoDiffusionPipeline
else:
    import sys

    # At runtime, swap this module object for a lazy proxy built from the table above.
    sys.modules[__name__] = _LazyModule(
        __name__,
        __file__,
        _import_structure,
        module_spec=__spec__,
    )
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
import importlib
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Any, Dict, Type, Union
|
|
19
|
+
|
|
20
|
+
from diffusers.models.controlnets import ControlNetUnionModel
|
|
21
|
+
from diffusers.pipelines.auto_pipeline import (
|
|
22
|
+
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
|
23
|
+
AUTO_INPAINT_PIPELINES_MAPPING,
|
|
24
|
+
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
|
25
|
+
AutoPipelineForImage2Image,
|
|
26
|
+
AutoPipelineForInpainting,
|
|
27
|
+
AutoPipelineForText2Image,
|
|
28
|
+
_get_task_class,
|
|
29
|
+
)
|
|
30
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
|
31
|
+
|
|
32
|
+
from optimum.rbln.configuration_utils import RBLNModelConfig
|
|
33
|
+
from optimum.rbln.modeling_base import RBLNBaseModel
|
|
34
|
+
from optimum.rbln.utils.model_utils import (
|
|
35
|
+
MODEL_MAPPING,
|
|
36
|
+
convert_hf_to_rbln_model_name,
|
|
37
|
+
convert_rbln_to_hf_model_name,
|
|
38
|
+
get_rbln_model_cls,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class RBLNAutoPipelineBase:
|
|
43
|
+
_model_mapping = None
|
|
44
|
+
_model_mapping_names = None
|
|
45
|
+
|
|
46
|
+
@classmethod
|
|
47
|
+
def get_rbln_cls(cls, pretrained_model_name_or_path: Union[str, Path], export: bool = None, **kwargs):
|
|
48
|
+
if isinstance(pretrained_model_name_or_path, Path):
|
|
49
|
+
pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()
|
|
50
|
+
|
|
51
|
+
if export is None:
|
|
52
|
+
export = not cls._is_compiled_pipeline(pretrained_model_name_or_path, **kwargs)
|
|
53
|
+
|
|
54
|
+
if export:
|
|
55
|
+
hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
|
|
56
|
+
rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
|
|
57
|
+
else:
|
|
58
|
+
rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
|
|
59
|
+
if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
|
|
60
|
+
raise ValueError(
|
|
61
|
+
f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
|
|
62
|
+
"Please use the `from_pretrained()` method of the appropriate class to load this model, "
|
|
63
|
+
f"or directly use '{rbln_class_name}.from_pretrained()`."
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
try:
|
|
67
|
+
rbln_cls = get_rbln_model_cls(rbln_class_name)
|
|
68
|
+
except AttributeError as e:
|
|
69
|
+
raise AttributeError(
|
|
70
|
+
f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
|
|
71
|
+
"Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
|
|
72
|
+
) from e
|
|
73
|
+
|
|
74
|
+
return rbln_cls
|
|
75
|
+
|
|
76
|
+
@classmethod
|
|
77
|
+
def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
|
|
78
|
+
"""
|
|
79
|
+
Retrieve the path to the compiled model directory for a given RBLN model.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
pretrained_model_name_or_path (str): Identifier of the model.
|
|
83
|
+
|
|
84
|
+
Returns:
|
|
85
|
+
str: Path to the compiled model directory.
|
|
86
|
+
"""
|
|
87
|
+
model_index_config = cls.load_config(pretrained_model_name_or_path)
|
|
88
|
+
|
|
89
|
+
if "_class_name" not in model_index_config:
|
|
90
|
+
raise ValueError(
|
|
91
|
+
"The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
|
|
92
|
+
"Please use the `from_pretrained()` method of the appropriate class to load this model."
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
return model_index_config["_class_name"]
|
|
96
|
+
|
|
97
|
+
@classmethod
|
|
98
|
+
def _is_compiled_pipeline(
|
|
99
|
+
cls,
|
|
100
|
+
pretrained_model_name_or_path: Union[str, Path],
|
|
101
|
+
cache_dir=None,
|
|
102
|
+
force_download=False,
|
|
103
|
+
proxies=None,
|
|
104
|
+
token=None,
|
|
105
|
+
local_files_only=False,
|
|
106
|
+
revision=None,
|
|
107
|
+
**kwargs,
|
|
108
|
+
):
|
|
109
|
+
config: dict = cls.load_config(
|
|
110
|
+
pretrained_model_name_or_path,
|
|
111
|
+
cache_dir=cache_dir,
|
|
112
|
+
force_download=force_download,
|
|
113
|
+
proxies=proxies,
|
|
114
|
+
token=token,
|
|
115
|
+
local_files_only=local_files_only,
|
|
116
|
+
revision=revision,
|
|
117
|
+
)
|
|
118
|
+
for value in config.values():
|
|
119
|
+
if isinstance(value, list) and len(value) > 0 and value[0] == "optimum.rbln":
|
|
120
|
+
return True
|
|
121
|
+
return False
|
|
122
|
+
|
|
123
|
+
@classmethod
|
|
124
|
+
def infer_hf_model_class(
|
|
125
|
+
cls,
|
|
126
|
+
pretrained_model_or_path: Union[str, Path],
|
|
127
|
+
cache_dir=None,
|
|
128
|
+
force_download=False,
|
|
129
|
+
proxies=None,
|
|
130
|
+
token=None,
|
|
131
|
+
local_files_only=False,
|
|
132
|
+
revision=None,
|
|
133
|
+
**kwargs,
|
|
134
|
+
):
|
|
135
|
+
config = cls.load_config(
|
|
136
|
+
pretrained_model_or_path,
|
|
137
|
+
cache_dir=cache_dir,
|
|
138
|
+
force_download=force_download,
|
|
139
|
+
proxies=proxies,
|
|
140
|
+
token=token,
|
|
141
|
+
local_files_only=local_files_only,
|
|
142
|
+
revision=revision,
|
|
143
|
+
)
|
|
144
|
+
pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)
|
|
145
|
+
|
|
146
|
+
pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)
|
|
147
|
+
|
|
148
|
+
return pipeline_cls
|
|
149
|
+
|
|
150
|
+
@classmethod
|
|
151
|
+
def get_pipeline_key_name(cls, config, **kwargs):
|
|
152
|
+
orig_class_name = config["_class_name"]
|
|
153
|
+
if "ControlPipeline" in orig_class_name:
|
|
154
|
+
to_replace = "ControlPipeline"
|
|
155
|
+
else:
|
|
156
|
+
to_replace = "Pipeline"
|
|
157
|
+
|
|
158
|
+
if "controlnet" in kwargs:
|
|
159
|
+
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
|
|
160
|
+
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
|
|
161
|
+
else:
|
|
162
|
+
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
|
|
163
|
+
if "enable_pag" in kwargs:
|
|
164
|
+
enable_pag = kwargs.pop("enable_pag")
|
|
165
|
+
if enable_pag:
|
|
166
|
+
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
|
|
167
|
+
|
|
168
|
+
return orig_class_name
|
|
169
|
+
|
|
170
|
+
@classmethod
|
|
171
|
+
@validate_hf_hub_args
|
|
172
|
+
def from_pretrained(
|
|
173
|
+
cls,
|
|
174
|
+
model_id: Union[str, Path],
|
|
175
|
+
*,
|
|
176
|
+
export: bool = None,
|
|
177
|
+
rbln_config: Union[Dict[str, Any], RBLNModelConfig] = {},
|
|
178
|
+
**kwargs: Any,
|
|
179
|
+
):
|
|
180
|
+
"""
|
|
181
|
+
Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.
|
|
182
|
+
|
|
183
|
+
This method determines the concrete `RBLN*` model class that corresponds to the
|
|
184
|
+
underlying Diffusers pipeline architecture and dispatches to that class's
|
|
185
|
+
`from_pretrained()` implementation. If a compiled RBLN folder is detected at `model_id`
|
|
186
|
+
(or `export=False` is explicitly passed), it loads the compiled artifacts; otherwise it
|
|
187
|
+
compiles from the original Diffusers checkpoint.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
model_id:
|
|
191
|
+
HF repo id or local path. For compiled models, this should point to a directory
|
|
192
|
+
(optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
|
|
193
|
+
export:
|
|
194
|
+
Force compilation from a Diffusers checkpoint. When `None`, this is inferred by
|
|
195
|
+
checking whether compiled artifacts exist at `model_id`.
|
|
196
|
+
rbln_config:
|
|
197
|
+
RBLN compilation/runtime configuration. May be provided as a dictionary or as an
|
|
198
|
+
instance of the specific model's config class (e.g., `RBLNFluxPipelineConfig`).
|
|
199
|
+
kwargs: Additional keyword arguments.
|
|
200
|
+
- Arguments prefixed with `rbln_` are forwarded to the RBLN config.
|
|
201
|
+
- Remaining arguments are forwarded to the Diffusers loader.
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
|
|
205
|
+
inference on RBLN NPUs.
|
|
206
|
+
|
|
207
|
+
"""
|
|
208
|
+
rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
|
|
209
|
+
return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
|
|
210
|
+
|
|
211
|
+
@staticmethod
def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
    """
    Register a custom RBLN model class in `MODEL_MAPPING` under its class name.

    Args:
        rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register. It must subclass
            `RBLNBaseModel`.
        exist_ok (bool): When False (the default), refuse names that are already present in
            `MODEL_MAPPING` or exported by the `optimum.rbln` package. When True, the existing
            entry is overwritten.

    Raises:
        ValueError: If `rbln_cls` is not an `RBLNBaseModel` subclass, or if its name is already
            taken and `exist_ok` is False.
    """
    if not issubclass(rbln_cls, RBLNBaseModel):
        raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")

    cls_name = rbln_cls.__name__
    # A name counts as taken if optimum.rbln itself exports it or it was registered earlier.
    shipped_natively = getattr(importlib.import_module("optimum.rbln"), cls_name, None) is not None
    already_registered = cls_name in MODEL_MAPPING

    if (already_registered or shipped_natively) and not exist_ok:
        raise ValueError(f"Model for {cls_name} already registered.")

    MODEL_MAPPING[cls_name] = rbln_cls
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
    """Text2Image AutoPipeline for RBLN NPUs.

    Pipeline classes are resolved through diffusers' ``AUTO_TEXT2IMAGE_PIPELINES_MAPPING``.
    """

    _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
    # Same keys as the mapping above, with each pipeline class replaced by its class name.
    _model_mapping_names = {
        key: pipeline_cls.__name__ for key, pipeline_cls in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()
    }
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
    """Image2Image AutoPipeline for RBLN NPUs.

    Pipeline classes are resolved through diffusers' ``AUTO_IMAGE2IMAGE_PIPELINES_MAPPING``.
    """

    _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
    # Same keys as the mapping above, with each pipeline class replaced by its class name.
    _model_mapping_names = {
        key: pipeline_cls.__name__ for key, pipeline_cls in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Rewrite a checkpoint's pipeline class name into its image-to-image variant.

        Args:
            config: Pipeline config mapping. Its ``"_class_name"`` entry is the original class name.
            **kwargs: Loader keyword arguments. If ``controlnet`` is present, ``ControlNet`` (or
                ``ControlNetUnion`` for a ``ControlNetUnionModel``) is inserted into the name. A
                truthy ``enable_pag`` inserts ``PAG``.

        Returns:
            str: The rewritten pipeline class name.
        """
        class_name = config["_class_name"]

        # The original class name is one of:
        #   - `*Img2ImgPipeline`  (refiner checkpoint)
        #   - `*ControlPipeline`  (Flux tools specific checkpoint)
        #   - `*Pipeline`         (regular text-to-image checkpoint)
        if "Img2Img" in class_name:
            suffix = "Img2ImgPipeline"
        elif "ControlPipeline" in class_name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            is_union = isinstance(kwargs["controlnet"], ControlNetUnionModel)
            prefix = "ControlNetUnion" if is_union else "ControlNet"
            class_name = class_name.replace(suffix, prefix + suffix)

        # `kwargs` is a fresh dict owned by this call, so reading the flag is equivalent to popping it.
        if kwargs.get("enable_pag"):
            class_name = class_name.replace(suffix, "PAG" + suffix)

        if suffix == "ControlPipeline":
            class_name = class_name.replace(suffix, "ControlImg2ImgPipeline")

        return class_name
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
    """Inpainting AutoPipeline for RBLN NPUs.

    Pipeline classes are resolved through diffusers' ``AUTO_INPAINT_PIPELINES_MAPPING``.
    """

    _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
    # Same keys as the mapping above, with each pipeline class replaced by its class name.
    _model_mapping_names = {
        key: pipeline_cls.__name__ for key, pipeline_cls in AUTO_INPAINT_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Rewrite a checkpoint's pipeline class name into its inpainting variant.

        Args:
            config: Pipeline config mapping. Its ``"_class_name"`` entry is the original class name.
            **kwargs: Loader keyword arguments. If ``controlnet`` is present, ``ControlNet`` (or
                ``ControlNetUnion`` for a ``ControlNetUnionModel``) is inserted into the name. A
                truthy ``enable_pag`` inserts ``PAG``.

        Returns:
            str: The rewritten pipeline class name.
        """
        class_name = config["_class_name"]

        # The original class name is one of:
        #   - `*InpaintPipeline`  (inpaint-specific checkpoint)
        #   - `*ControlPipeline`  (Flux tools specific checkpoint)
        #   - `*Pipeline`         (regular text-to-image checkpoint)
        if "Inpaint" in class_name:
            suffix = "InpaintPipeline"
        elif "ControlPipeline" in class_name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            is_union = isinstance(kwargs["controlnet"], ControlNetUnionModel)
            prefix = "ControlNetUnion" if is_union else "ControlNet"
            class_name = class_name.replace(suffix, prefix + suffix)

        # `kwargs` is a fresh dict owned by this call, so reading the flag is equivalent to popping it.
        if kwargs.get("enable_pag"):
            class_name = class_name.replace(suffix, "PAG" + suffix)

        if suffix == "ControlPipeline":
            class_name = class_name.replace(suffix, "ControlInpaintPipeline")

        return class_name
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .multicontrolnet import RBLNMultiControlNetModel
|
|
16
|
+
from .pipeline_controlnet import RBLNStableDiffusionControlNetPipeline
|
|
17
|
+
from .pipeline_controlnet_img2img import RBLNStableDiffusionControlNetImg2ImgPipeline
|
|
18
|
+
from .pipeline_controlnet_sd_xl import RBLNStableDiffusionXLControlNetPipeline
|
|
19
|
+
from .pipeline_controlnet_sd_xl_img2img import RBLNStableDiffusionXLControlNetImg2ImgPipeline
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Dict, List, Optional, Union
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
|
21
|
+
|
|
22
|
+
from ....modeling import RBLNModel
|
|
23
|
+
from ....utils.logging import get_logger
|
|
24
|
+
from ...models.controlnet import RBLNControlNetModel
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
logger = get_logger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RBLNMultiControlNetModel(RBLNModel):
    """
    Wrapper that runs several compiled `RBLNControlNetModel`s and sums their residuals.

    This is the RBLN counterpart of diffusers' `MultiControlNetModel`. On disk the nets live in
    sibling directories: the first at `<path>`, the following ones at `<path>_1`, `<path>_2`, ...
    (written by `save_pretrained`, read back by `_from_pretrained`).
    """

    # Library / original class this wrapper stands in for.
    hf_library_name = "diffusers"
    _hf_class = MultiControlNetModel

    def __init__(
        self,
        models: List[RBLNControlNetModel],
    ):
        # NOTE(review): `super().__init__()` is not called; this wrapper only stores the nets
        # and delegates to them.
        self.nets = models
        self.dtype = torch.float32

    @property
    def compiled_models(self):
        """Compiled models of every wrapped net, concatenated in net order."""
        return [compiled for net in self.nets for compiled in net.compiled_models]

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        **kwargs,
    ) -> RBLNModel:
        """
        Load every compiled ControlNet stored at `model_id` (optionally under `subfolder`).

        The first net is read from the base directory. Later nets are read from `<base>_1`,
        `<base>_2`, ... until a directory is missing.

        Args:
            model_id: Local directory (str or `pathlib.Path`) holding the compiled ControlNets.
            **kwargs: `subfolder` is consumed here. All other arguments are forwarded to
                `RBLNControlNetModel.from_pretrained`.

        Returns:
            RBLNMultiControlNetModel: Wrapper holding the loaded nets in index order.

        Raises:
            FileNotFoundError: If no compiled ControlNet directory exists at the resolved path.
        """
        subfolder_name = kwargs.pop("subfolder", None)
        # `os.fspath` accepts both str and pathlib.Path; the old `+` concatenation failed on Path.
        base_path = os.fspath(model_id)
        if subfolder_name is not None:
            base_path = os.path.join(base_path, subfolder_name)

        controlnets = []
        model_path_to_load = base_path
        while os.path.isdir(model_path_to_load):
            controlnets.append(RBLNControlNetModel.from_pretrained(model_path_to_load, export=False, **kwargs))
            model_path_to_load = f"{base_path}_{len(controlnets)}"

        if not controlnets:
            # An empty wrapper cannot run `forward`, so fail early with a clear message.
            raise FileNotFoundError(f"No compiled ControlNet directory found at '{base_path}'.")

        return cls(controlnets)

    def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
        """
        Save each wrapped net to `<save_directory>`, `<save_directory>_1`, `<save_directory>_2`, ...

        This is the layout `_from_pretrained` reads back. `save_directory` may be a str or a
        `pathlib.Path`.
        """
        # NOTE(review): extra `kwargs` are accepted but not forwarded to the nets' save_pretrained.
        base_path = os.fspath(save_directory)
        for idx, model in enumerate(self.nets):
            suffix = "" if idx == 0 else f"_{idx}"
            model.save_pretrained(base_path + suffix)

    @classmethod
    def _update_rbln_config(cls, **rbln_config_kwargs):
        # NOTE(review): no-op. The wrapped nets are loaded with `export=False` in
        # `_from_pretrained`, so this wrapper has no RBLN config of its own to update; confirm.
        pass

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ):
        """
        Forward pass for the RBLN-optimized MultiControlNetModel.

        Each wrapped ControlNet is run on `sample` with its own conditioning image and scale
        factor. The outputs of all nets are merged by element-wise addition.

        Args:
            sample (torch.FloatTensor): The noisy input tensor.
            timestep (Union[torch.Tensor, float, int]): The current denoising timestep. Python
                numbers are converted to a tensor.
            encoder_hidden_states (torch.Tensor): The encoder hidden states from the text encoder.
            controlnet_cond (List[torch.Tensor]): One conditioning input per ControlNet.
            conditioning_scale (List[float]): One scale factor per ControlNet output. Each scale
                controls how strongly its ControlNet influences the generation.
            return_dict (bool): Ignored. This method always returns a tuple, and the nets are
                always called with `return_dict=False` so their outputs unpack correctly.

            `class_labels`, `timestep_cond`, `attention_mask`, `added_cond_kwargs`,
            `cross_attention_kwargs` and `guess_mode` are accepted but not used (presumably for
            signature compatibility with diffusers' `MultiControlNetModel.forward`).

        Returns:
            (down_block_res_samples, mid_block_res_sample): Residuals summed across all nets.
            `down_block_res_samples` is a list when more than one net ran.

        Raises:
            ValueError: If there are no nets or no conditioning inputs to run.
        """
        # Loop-invariant conversions, done once instead of once per net.
        sample = sample.contiguous()
        timestep = torch.as_tensor(timestep).float()

        down_block_res_samples = None
        mid_block_res_sample = None
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                conditioning_scale=torch.as_tensor(scale),
                # A `ControlNetOutput` (dict-like) would tuple-unpack to its *keys*, so always
                # request a plain (down_samples, mid_sample) tuple.
                return_dict=False,
            )

            # Merge samples. Addition is out of place so a net's returned tensors are never mutated.
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample = mid_block_res_sample + mid_sample

        if mid_block_res_sample is None:
            raise ValueError("RBLNMultiControlNetModel.forward requires at least one ControlNet and conditioning input.")

        return down_block_res_samples, mid_block_res_sample
|