optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, Optional, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from diffusers import ControlNetModel
|
|
19
|
+
from diffusers.models.controlnets.controlnet import ControlNetOutput
|
|
20
|
+
from transformers import PretrainedConfig
|
|
21
|
+
|
|
22
|
+
from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
|
23
|
+
from ...modeling import RBLNModel
|
|
24
|
+
from ...utils.logging import get_logger
|
|
25
|
+
from ...utils.model_utils import get_rbln_model_cls
|
|
26
|
+
from ..configurations import RBLNControlNetModelConfig
|
|
27
|
+
from ..modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Logger named after this module, obtained from the package's `get_logger` helper.
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class _ControlNetModel(torch.nn.Module):
    """Tracing wrapper around a diffusers ControlNet that takes no `encoder_hidden_states`.

    `RBLNControlNetModel._wrap_model_if_needed` selects this wrapper when none of the
    controlnet's down blocks report cross-attention. The wrapped controlnet is always
    invoked with `encoder_hidden_states=None` and `return_dict=False`.
    """

    def __init__(self, controlnet: "ControlNetModel"):
        super().__init__()
        self.controlnet = controlnet

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale,
        text_embeds: Optional[torch.Tensor] = None,
        time_ids: Optional[torch.Tensor] = None,
    ):
        """Run the controlnet and return `(down_block_res_samples, mid_block_res_sample)`.

        `text_embeds` and `time_ids` are forwarded as `added_cond_kwargs` only when BOTH
        are given; if either is missing, no added conditions are passed at all.
        """
        has_added_conds = text_embeds is not None and time_ids is not None
        added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} if has_added_conds else {}

        results = self.controlnet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=None,
            controlnet_cond=controlnet_cond,
            conditioning_scale=conditioning_scale,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )
        # With return_dict=False the controlnet yields exactly (down samples, mid sample).
        down_samples, mid_sample = results
        return down_samples, mid_sample
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class _ControlNetModel_Cross_Attention(torch.nn.Module):
    """Tracing wrapper around a diffusers ControlNet that consumes `encoder_hidden_states`.

    `RBLNControlNetModel._wrap_model_if_needed` selects this wrapper when at least one of
    the controlnet's down blocks reports cross-attention. The wrapped controlnet is always
    invoked with `return_dict=False`.
    """

    def __init__(self, controlnet: "ControlNetModel"):
        super().__init__()
        self.controlnet = controlnet

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale,
        text_embeds: Optional[torch.Tensor] = None,
        time_ids: Optional[torch.Tensor] = None,
    ):
        """Run the controlnet and return `(down_block_res_samples, mid_block_res_sample)`.

        `text_embeds` and `time_ids` are forwarded as `added_cond_kwargs` only when BOTH
        are given; otherwise an empty dict is passed.
        """
        added_cond_kwargs = {}
        if not (text_embeds is None or time_ids is None):
            added_cond_kwargs.update(text_embeds=text_embeds, time_ids=time_ids)

        controlnet_kwargs = dict(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            conditioning_scale=conditioning_scale,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )
        down_block_res_samples, mid_block_res_sample = self.controlnet(**controlnet_kwargs)
        return down_block_res_samples, mid_block_res_sample
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class RBLNControlNetModel(RBLNModel):
    """
    RBLN implementation of ControlNetModel for diffusion models.

    This model is used to accelerate ControlNetModel models from diffusers library on RBLN NPUs.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    hf_library_name = "diffusers"
    auto_model_class = ControlNetModel
    output_class = ControlNetOutput

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)
        # Whether the compiled graph takes `encoder_hidden_states` is decided at compile time
        # (see `_update_rbln_config`), so read it back from the compiled input signature.
        self.use_encoder_hidden_states = any(
            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
        )

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        """Wrap `model` in the tracing module whose forward signature matches its architecture.

        Controlnets with at least one down block reporting `has_cross_attention` get the wrapper
        that accepts `encoder_hidden_states`; all others get the wrapper without it.
        """
        use_encoder_hidden_states = any(
            getattr(down_block, "has_cross_attention", False) for down_block in model.down_blocks
        )
        if use_encoder_hidden_states:
            return _ControlNetModel_Cross_Attention(model).eval()
        return _ControlNetModel(model).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Fill `rbln_config.controlnet` with shape information taken from the pipeline.

        Sets `max_seq_len` from the text encoder, `text_model_hidden_size` from `text_encoder_2`
        (or None when the pipeline has no such attribute), and the VAE/UNet sample sizes from the
        RBLN classes matching `pipe.vae` and `pipe.unet`.
        """
        rbln_vae_cls = get_rbln_model_cls(f"RBLN{pipe.vae.__class__.__name__}")
        rbln_unet_cls = get_rbln_model_cls(f"RBLN{pipe.unet.__class__.__name__}")

        rbln_config.controlnet.max_seq_len = pipe.text_encoder.config.max_position_embeddings
        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
        rbln_config.controlnet.text_model_hidden_size = text_model_hidden_size
        rbln_config.controlnet.vae_sample_size = rbln_vae_cls.get_vae_sample_size(pipe, rbln_config.vae)
        rbln_config.controlnet.unet_sample_size = rbln_unet_cls.get_unet_sample_size(
            pipe, rbln_config.unet, image_size=rbln_config.image_size
        )

        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNControlNetModelConfig,
    ) -> RBLNModelConfig:
        """Build the single compile config (input names, shapes, dtypes) for the controlnet.

        Raises:
            ValueError: If `unet_sample_size`, `vae_sample_size` or `max_seq_len` is not set.
        """
        if rbln_config.unet_sample_size is None:
            raise ValueError("`unet_sample_size` (latent height, width) must be specified (ex. unet's sample_size)")

        if rbln_config.vae_sample_size is None:
            raise ValueError("`vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)")

        if rbln_config.max_seq_len is None:
            raise ValueError("`max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified")

        input_info = [
            (
                "sample",
                [
                    rbln_config.batch_size,
                    model_config.in_channels,
                    rbln_config.unet_sample_size[0],
                    rbln_config.unet_sample_size[1],
                ],
                "float32",
            ),
            ("timestep", [], "float32"),
        ]

        # Any down block other than a plain DownBlock2D is treated as needing encoder_hidden_states.
        use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
        if use_encoder_hidden_states:
            input_info.append(
                (
                    "encoder_hidden_states",
                    [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
                    "float32",
                )
            )

        input_info.append(
            (
                "controlnet_cond",
                [rbln_config.batch_size, 3, rbln_config.vae_sample_size[0], rbln_config.vae_sample_size[1]],
                "float32",
            )
        )
        input_info.append(("conditioning_scale", [], "float32"))

        # SDXL-style controlnets additionally consume pooled text embeddings and 6 time ids.
        if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
            input_info.append(("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32"))
            input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    @property
    def compiled_batch_size(self):
        """Batch dimension of the first compiled input (`sample`)."""
        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: Union[float, torch.Tensor] = 1.0,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for the RBLN-optimized ControlNetModel.

        Args:
            sample (torch.FloatTensor): The noisy latent input, compiled with shape
                `(batch_size, in_channels, unet_sample_size[0], unet_sample_size[1])`.
            timestep (Union[torch.Tensor, float, int]): The current denoising timestep. Python numbers
                are accepted and converted to a float32 tensor.
            encoder_hidden_states (torch.Tensor): The encoder hidden states. Only passed to the compiled
                model when it was compiled with an `encoder_hidden_states` input; ignored otherwise.
            controlnet_cond (torch.FloatTensor): The conditioning image, compiled with shape
                `(batch_size, 3, vae_sample_size[0], vae_sample_size[1])`.
            conditioning_scale (Union[float, torch.Tensor]): The scale factor for ControlNet outputs,
                converted to a float32 tensor.
            added_cond_kwargs (Optional[Dict[str, torch.Tensor]]): Additional conditions for the Stable
                Diffusion XL UNet (`text_embeds`, `time_ids`). `None` means no additional conditions.
            return_dict (bool): Whether or not to return a
                [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A `ControlNetOutput` when `return_dict` is True, otherwise the tuple
            `(down_block_res_samples, mid_block_res_sample)`.

        Raises:
            ValueError: If the runtime batch size is exactly double or half the compiled batch size
                (typically caused by classifier-free guidance).
        """
        sample_batch_size = sample.size()[0]
        compiled_batch_size = self.compiled_batch_size
        if sample_batch_size != compiled_batch_size and (
            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
        ):
            raise ValueError(
                f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                "This may be caused by the 'guidance_scale' parameter, which doubles the runtime batch size of ControlNet in Stable Diffusion. "
                "Adjust the batch size of ControlNet during compilation to match the runtime batch size.\n\n"
                "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/controlnet.html#important-batch-size-configuration-for-guidance-scale"
            )

        added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
        # Both scalars are compiled as float32. `as_tensor` accepts Python numbers (which
        # `.float()` alone does not) and avoids the copy-construct warning `torch.tensor`
        # emits when it is handed an existing tensor.
        timestep_tensor = torch.as_tensor(timestep).float()
        scale_tensor = torch.as_tensor(conditioning_scale, dtype=torch.float32)

        if self.use_encoder_hidden_states:
            output = self.model[0](
                sample.contiguous(),
                timestep_tensor,
                encoder_hidden_states,
                controlnet_cond,
                scale_tensor,
                **added_cond_kwargs,
            )
        else:
            output = self.model[0](
                sample.contiguous(),
                timestep_tensor,
                controlnet_cond,
                scale_tensor,
                **added_cond_kwargs,
            )

        # The runtime returns a flat sequence: all down-block residuals followed by the mid-block residual.
        down_block_res_samples = output[:-1]
        mid_block_res_sample = output[-1]
        return self._prepare_output((down_block_res_samples, mid_block_res_sample), return_dict)

    def _prepare_output(self, output, return_dict):
        """Format the `(down_block_res_samples, mid_block_res_sample)` pair built by `forward`.

        Returns the pair unchanged when `return_dict` is False. Otherwise, wraps it in a
        `ControlNetOutput`. The pair is unpacked rather than sliced so that
        `down_block_res_samples` is not wrapped in an extra 1-tuple.
        """
        if not return_dict:
            return (output,) if not isinstance(output, (tuple, list)) else output
        down_block_res_samples, mid_block_res_sample = output
        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .prior_transformer import RBLNPriorTransformer
|
|
16
|
+
from .transformer_cosmos import RBLNCosmosTransformer3DModel
|
|
17
|
+
from .transformer_sd3 import RBLNSD3Transformer2DModel
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
|
|
20
|
+
|
|
21
|
+
from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
|
22
|
+
from ....modeling import RBLNModel
|
|
23
|
+
from ....utils.logging import get_logger
|
|
24
|
+
from ...configurations.models import RBLNPriorTransformerConfig
|
|
25
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
30
|
+
|
|
31
|
+
# Module-level logger for this file, obtained from the package's logging utilities.
logger = get_logger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class _PriorTransformer(torch.nn.Module):
    """Compilation wrapper around a diffusers ``PriorTransformer``.

    ``forward`` always calls the wrapped model's ``forward`` with ``return_dict=False``,
    so the result is the model's tuple output rather than an output dataclass.
    """

    def __init__(self, prior: PriorTransformer):
        super().__init__()
        # Keep the attribute name ``_prior``: it becomes the prefix of the wrapped
        # module's parameter names in ``state_dict()``.
        self._prior = prior

    def forward(
        self,
        hidden_states,
        timestep,
        proj_embedding,
        encoder_hidden_states,
        attention_mask,
        return_dict=True,
    ):
        # ``return_dict`` exists only to mirror the caller-facing signature; it is
        # deliberately ignored and tuple output is always requested.
        prior_forward = self._prior.forward
        positional_inputs = (
            hidden_states,
            timestep,
            proj_embedding,
            encoder_hidden_states,
            attention_mask,
        )
        return prior_forward(*positional_inputs, return_dict=False)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class RBLNPriorTransformer(RBLNModel):
    """
    RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.

    The PriorTransformer takes text and/or image embeddings from encoders (like CLIP) and
    maps them to a shared latent space that guides the diffusion process to generate the desired image.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.

    In addition to the compiled graph, the source model's `clip_mean` / `clip_std` are saved to
    `torch_artifacts.pth` and reloaded at init, so that `post_process_latents` can be applied on the host.
    """

    hf_library_name = "diffusers"
    auto_model_class = PriorTransformer
    _output_class = PriorTransformerOutput

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which executes
        # code from the file. `save_torch_artifacts` only stores `model.clip_mean` / `model.clip_std`
        # (presumably tensors), so weights_only=True is likely sufficient. Confirm and tighten before
        # loading model directories from untrusted sources.
        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
        self.clip_mean = artifacts["clip_mean"]
        self.clip_std = artifacts["clip_std"]

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        """Wrap `model` so its forward always returns a tuple, and put it in eval mode."""
        return _PriorTransformer(model).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Return `rbln_config` unchanged; no settings are derived from the pipeline for this submodule."""
        return rbln_config

    @classmethod
    def save_torch_artifacts(
        cls, model: "PreTrainedModel", save_dir_path: Path, subfolder: str, rbln_config: RBLNModelConfig
    ):
        """Save `model.clip_mean` and `model.clip_std` to `<save_dir_path>/<subfolder>/torch_artifacts.pth`."""
        save_dict = {}
        save_dict["clip_mean"] = model.clip_mean
        save_dict["clip_std"] = model.clip_std
        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNPriorTransformerConfig,
    ) -> RBLNPriorTransformerConfig:
        """Fill missing dims from `model_config` and register the static input signature for compilation."""
        # Values explicitly set on rbln_config take precedence over the model config.
        rbln_config.embedding_dim = rbln_config.embedding_dim or model_config.embedding_dim
        rbln_config.num_embeddings = rbln_config.num_embeddings or model_config.num_embeddings

        # All inputs are float32; timestep is a scalar (shape []).
        input_info = [
            ("hidden_states", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
            ("timestep", [], "float32"),
            ("proj_embedding", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
            (
                "encoder_hidden_states",
                [rbln_config.batch_size, rbln_config.num_embeddings, rbln_config.embedding_dim],
                "float32",
            ),
            ("attention_mask", [rbln_config.batch_size, rbln_config.num_embeddings], "float32"),
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    def post_process_latents(self, prior_latents):
        """Return `prior_latents * clip_std + clip_mean` using the artifacts loaded at init."""
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Forward pass for the RBLN-optimized PriorTransformer.

        Args:
            hidden_states (torch.Tensor): The currently predicted image embeddings.
            timestep (Union[torch.Tensor, float, int]): Current denoising step. Python numbers are converted
                to a 0-d float32 tensor, matching the scalar `timestep` input of the compiled graph.
            proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising
                process is conditioned on. Required by the compiled graph.
            attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings. Required by the compiled
                graph; it is cast to float before being passed on.
            return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.

        Returns:
            (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])

        Raises:
            ValueError: If `encoder_hidden_states` or `attention_mask` is None. Both are fixed inputs of the
                compiled graph (see `_update_rbln_config`).
        """
        if encoder_hidden_states is None or attention_mask is None:
            raise ValueError(
                "RBLNPriorTransformer was compiled with `encoder_hidden_states` and `attention_mask` as "
                "required inputs; both must be provided."
            )

        # The compiled graph takes timestep as a float32 scalar, so convert long tensors or Python numbers.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.float()
        else:
            timestep = torch.tensor(timestep, dtype=torch.float32)

        # Convert attention_mask (e.g. bool) to float to match the compiled input dtype.
        return super().forward(
            hidden_states,
            timestep,
            proj_embedding,
            encoder_hidden_states,
            attention_mask.float(),
            return_dict=return_dict,
        )
|