optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
|
|
20
|
+
from transformers import PretrainedConfig
|
|
21
|
+
|
|
22
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
23
|
+
from ....modeling import RBLNModel
|
|
24
|
+
from ....utils.logging import get_logger
|
|
25
|
+
from ...configurations import RBLNUNet2DConditionModelConfig
|
|
26
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
31
|
+
|
|
32
|
+
# Module-level logger for this file (used e.g. for the missing-image_size warning).
logger = get_logger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class _UNet_SD(torch.nn.Module):
|
|
36
|
+
def __init__(self, unet: "UNet2DConditionModel"):
|
|
37
|
+
super().__init__()
|
|
38
|
+
self.unet = unet
|
|
39
|
+
|
|
40
|
+
def forward(
|
|
41
|
+
self,
|
|
42
|
+
sample: torch.Tensor,
|
|
43
|
+
timestep: Union[torch.Tensor, float, int],
|
|
44
|
+
encoder_hidden_states: torch.Tensor,
|
|
45
|
+
*down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
|
|
46
|
+
text_embeds: Optional[torch.Tensor] = None,
|
|
47
|
+
time_ids: Optional[torch.Tensor] = None,
|
|
48
|
+
) -> torch.Tensor:
|
|
49
|
+
if text_embeds is not None and time_ids is not None:
|
|
50
|
+
added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
|
|
51
|
+
else:
|
|
52
|
+
added_cond_kwargs = {}
|
|
53
|
+
|
|
54
|
+
if len(down_and_mid_block_additional_residuals) != 0:
|
|
55
|
+
down_block_additional_residuals, mid_block_additional_residual = (
|
|
56
|
+
down_and_mid_block_additional_residuals[:-1],
|
|
57
|
+
down_and_mid_block_additional_residuals[-1],
|
|
58
|
+
)
|
|
59
|
+
else:
|
|
60
|
+
down_block_additional_residuals, mid_block_additional_residual = None, None
|
|
61
|
+
|
|
62
|
+
unet_out = self.unet(
|
|
63
|
+
sample=sample,
|
|
64
|
+
timestep=timestep,
|
|
65
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
66
|
+
down_block_additional_residuals=down_block_additional_residuals,
|
|
67
|
+
mid_block_additional_residual=mid_block_additional_residual,
|
|
68
|
+
added_cond_kwargs=added_cond_kwargs,
|
|
69
|
+
return_dict=False,
|
|
70
|
+
)
|
|
71
|
+
return unet_out
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class _UNet_SDXL(torch.nn.Module):
|
|
75
|
+
def __init__(self, unet: "UNet2DConditionModel"):
|
|
76
|
+
super().__init__()
|
|
77
|
+
self.unet = unet
|
|
78
|
+
|
|
79
|
+
def forward(
|
|
80
|
+
self,
|
|
81
|
+
sample: torch.Tensor,
|
|
82
|
+
timestep: Union[torch.Tensor, float, int],
|
|
83
|
+
encoder_hidden_states: torch.Tensor,
|
|
84
|
+
*down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
|
|
85
|
+
) -> torch.Tensor:
|
|
86
|
+
if len(down_and_mid_block_additional_residuals) == 2:
|
|
87
|
+
added_cond_kwargs = {
|
|
88
|
+
"text_embeds": down_and_mid_block_additional_residuals[0],
|
|
89
|
+
"time_ids": down_and_mid_block_additional_residuals[1],
|
|
90
|
+
}
|
|
91
|
+
down_block_additional_residuals = None
|
|
92
|
+
mid_block_additional_residual = None
|
|
93
|
+
elif len(down_and_mid_block_additional_residuals) > 2:
|
|
94
|
+
added_cond_kwargs = {
|
|
95
|
+
"text_embeds": down_and_mid_block_additional_residuals[-2],
|
|
96
|
+
"time_ids": down_and_mid_block_additional_residuals[-1],
|
|
97
|
+
}
|
|
98
|
+
down_block_additional_residuals, mid_block_additional_residual = (
|
|
99
|
+
down_and_mid_block_additional_residuals[:-3],
|
|
100
|
+
down_and_mid_block_additional_residuals[-3],
|
|
101
|
+
)
|
|
102
|
+
else:
|
|
103
|
+
added_cond_kwargs = {}
|
|
104
|
+
down_block_additional_residuals = None
|
|
105
|
+
mid_block_additional_residual = None
|
|
106
|
+
|
|
107
|
+
unet_out = self.unet(
|
|
108
|
+
sample=sample,
|
|
109
|
+
timestep=timestep,
|
|
110
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
111
|
+
down_block_additional_residuals=down_block_additional_residuals,
|
|
112
|
+
mid_block_additional_residual=mid_block_additional_residual,
|
|
113
|
+
added_cond_kwargs=added_cond_kwargs,
|
|
114
|
+
return_dict=False,
|
|
115
|
+
)
|
|
116
|
+
return unet_out
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class _UNet_Kandinsky(torch.nn.Module):
|
|
120
|
+
def __init__(self, unet: "UNet2DConditionModel"):
|
|
121
|
+
super().__init__()
|
|
122
|
+
self.unet = unet
|
|
123
|
+
|
|
124
|
+
def forward(
|
|
125
|
+
self,
|
|
126
|
+
sample: torch.Tensor,
|
|
127
|
+
timestep: Union[torch.Tensor, float, int],
|
|
128
|
+
image_embeds: torch.Tensor,
|
|
129
|
+
) -> torch.Tensor:
|
|
130
|
+
added_cond_kwargs = {"image_embeds": image_embeds}
|
|
131
|
+
|
|
132
|
+
unet_out = self.unet(
|
|
133
|
+
sample=sample,
|
|
134
|
+
timestep=timestep,
|
|
135
|
+
encoder_hidden_states=None,
|
|
136
|
+
added_cond_kwargs=added_cond_kwargs,
|
|
137
|
+
return_dict=False,
|
|
138
|
+
)
|
|
139
|
+
return unet_out
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class RBLNUNet2DConditionModel(RBLNModel):
|
|
143
|
+
"""
|
|
144
|
+
RBLN implementation of UNet2DConditionModel for diffusion models.
|
|
145
|
+
|
|
146
|
+
This model is used to accelerate UNet2DCondition models from diffusers library on RBLN NPUs.
|
|
147
|
+
It is a key component in diffusion-based image generation models like Stable Diffusion.
|
|
148
|
+
|
|
149
|
+
This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
|
|
150
|
+
the library implements for all its models.
|
|
151
|
+
"""
|
|
152
|
+
|
|
153
|
+
hf_library_name = "diffusers"
|
|
154
|
+
auto_model_class = UNet2DConditionModel
|
|
155
|
+
_rbln_config_class = RBLNUNet2DConditionModelConfig
|
|
156
|
+
_output_class = UNet2DConditionOutput
|
|
157
|
+
|
|
158
|
+
def __post_init__(self, **kwargs):
    """
    Run the base ``__post_init__`` and attach an ``add_embedding`` shim when configured.

    If ``rbln_config.in_features`` is set, ``self.add_embedding.linear_1.in_features`` becomes
    readable as a plain attribute path. NOTE(review): presumably this mirrors the attribute
    layout of diffusers' ``UNet2DConditionModel`` so pipeline code reading it keeps working.
    """
    super().__post_init__(**kwargs)
    self.in_features = self.rbln_config.in_features
    if self.in_features is None:
        # Nothing to expose without in_features.
        return

    @dataclass
    class LINEAR1:
        in_features: int

    @dataclass
    class ADDEMBEDDING:
        linear_1: LINEAR1

    linear_1 = LINEAR1(in_features=self.in_features)
    self.add_embedding = ADDEMBEDDING(linear_1=linear_1)
|
|
172
|
+
|
|
173
|
+
@classmethod
def _wrap_model_if_needed(
    cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
) -> torch.nn.Module:
    """
    Pick the tracing wrapper that matches the UNet's ``addition_embed_type``.

    ``"text_time"`` maps to ``_UNet_SDXL``, ``"image"`` to ``_UNet_Kandinsky``, and any other
    value to ``_UNet_SD``. The wrapper is returned in eval mode. ``rbln_config`` is not read.
    """
    wrapper_by_embed_type = {
        "text_time": _UNet_SDXL,
        "image": _UNet_Kandinsky,
    }
    wrapper_cls = wrapper_by_embed_type.get(model.config.addition_embed_type, _UNet_SD)
    return wrapper_cls(model).eval()
|
|
183
|
+
|
|
184
|
+
@classmethod
def get_unet_sample_size(
    cls,
    pipe: RBLNDiffusionMixin,
    rbln_config: RBLNUNet2DConditionModelConfig,
    image_size: Optional[Tuple[int, int]] = None,
) -> Tuple[int, int]:
    """
    Compute the UNet (latent) sample size for a pipeline.

    The pixel-to-latent scale factor comes from ``pipe.movq`` (``2 ** (num blocks - 1)``) when
    present, otherwise from ``pipe.vae_scale_factor``.

    Args:
        pipe: The diffusion pipeline that owns the UNet.
        rbln_config: The UNet's RBLN config. Not read here; kept for interface compatibility.
        image_size: Optional ``(height, width)`` in pixels. When given, it is divided by the
            scale factor and returned.

    Returns:
        The sample size. Derived sizes are ``(height, width)`` tuples; a non-int
        ``pipe.unet.config.sample_size`` is returned as stored.
    """
    if hasattr(pipe, "movq"):
        scale_factor = 2 ** (len(pipe.movq.config.block_out_channels) - 1)
    else:
        scale_factor = pipe.vae_scale_factor

    if image_size is not None:
        return (image_size[0] // scale_factor, image_size[1] // scale_factor)

    if "Img2Img" in pipe.__class__.__name__:
        if hasattr(pipe, "vae"):
            # In case of img2img, sample size of unet is determined by vae encoder.
            vae_sample_size = pipe.vae.config.sample_size
            if isinstance(vae_sample_size, int):
                vae_sample_size = (vae_sample_size, vae_sample_size)
            return (vae_sample_size[0] // scale_factor, vae_sample_size[1] // scale_factor)
        if hasattr(pipe, "movq"):
            logger.warning(
                "RBLN config 'image_size' should have been provided for this pipeline. "
                "Both variable will be set 512 by default."
            )
            return (512 // scale_factor, 512 // scale_factor)
        # Img2Img pipeline with neither a vae nor a movq: previously `sample_size` was left
        # unbound here (UnboundLocalError). Fall back to the UNet's own config below.

    sample_size = pipe.unet.config.sample_size
    if isinstance(sample_size, int):
        sample_size = (sample_size, sample_size)
    return sample_size
|
|
222
|
+
|
|
223
|
+
@classmethod
def update_rbln_config_using_pipe(
    cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
) -> "RBLNDiffusionMixinConfig":
    """
    Fill pipeline-derived fields of the ``unet`` sub-config.

    Reads the text/image hidden sizes and the text sequence length from the pipeline's
    components, computes the UNet sample size, and flags whether additional-residual inputs
    are needed.

    Args:
        pipe: The source diffusion pipeline.
        rbln_config: The pipeline-level RBLN config; its ``unet`` sub-config is updated in place.
        submodule_name: Name of the submodule being configured (not read here).

    Returns:
        The same ``rbln_config`` object.
    """
    unet_config = rbln_config.unet

    if hasattr(pipe, "text_encoder_2"):
        unet_config.text_model_hidden_size = pipe.text_encoder_2.config.hidden_size
    else:
        unet_config.text_model_hidden_size = None

    if hasattr(pipe, "unet"):
        unet_config.image_model_hidden_size = pipe.unet.config.encoder_hid_dim
    else:
        unet_config.image_model_hidden_size = None

    if hasattr(pipe, "text_encoder"):
        unet_config.max_seq_len = pipe.text_encoder.config.max_position_embeddings
    else:
        unet_config.max_seq_len = None

    unet_config.sample_size = cls.get_unet_sample_size(pipe, unet_config, image_size=rbln_config.image_size)
    # A "controlnet" entry in the pipeline config enables the additional-residual inputs.
    unet_config.use_additional_residuals = "controlnet" in pipe.config.keys()

    return rbln_config
|
|
242
|
+
|
|
243
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
    model: "PreTrainedModel",
    model_config: "PretrainedConfig",
    rbln_config: RBLNUNet2DConditionModelConfig,
) -> RBLNUNet2DConditionModelConfig:
    """
    Build the static compile-time input signature of the UNet and attach it to ``rbln_config``.

    Inputs are registered in this order (all ``float32``):

    1. ``sample``: ``[batch, in_channels, H, W]`` with ``(H, W)`` from ``rbln_config.sample_size``.
    2. ``timestep``: a scalar (shape ``[]``).
    3. ``encoder_hidden_states``: ``[batch, max_seq_len, cross_attention_dim]``, only when
       ``rbln_config.max_seq_len`` is set.
    4. ``down_block_additional_residuals_*`` then ``mid_block_additional_residual``, only when
       ``rbln_config.use_additional_residuals`` is set.
    5. ``text_embeds`` and ``time_ids`` for ``addition_embed_type == "text_time"``, or
       ``image_embeds`` for ``addition_embed_type == "image"``.

    Args:
        preprocessors: Not read here.
        model: Not read here.
        model_config: The UNet config the shapes are derived from.
        rbln_config: The RBLN UNet config. A missing ``sample_size`` is taken from
            ``model_config`` and an int one is expanded to ``(size, size)``; ``in_features``
            is set for ``text_time`` UNets; the compile config is attached.

    Returns:
        The updated ``rbln_config``.
    """
    if rbln_config.sample_size is None:
        rbln_config.sample_size = model_config.sample_size

    if isinstance(rbln_config.sample_size, int):
        rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

    batch_size = rbln_config.batch_size
    sample_height, sample_width = rbln_config.sample_size[0], rbln_config.sample_size[1]

    input_info = [
        ("sample", [batch_size, model_config.in_channels, sample_height, sample_width], "float32"),
        ("timestep", [], "float32"),
    ]

    if rbln_config.max_seq_len is not None:
        input_info.append(
            (
                "encoder_hidden_states",
                [batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
                "float32",
            )
        )

    if rbln_config.use_additional_residuals:
        num_down_blocks = len(model_config.down_block_types)
        # `layers_per_block` may be a single int or one value per down block.
        layers_per_block = model_config.layers_per_block
        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * num_down_blocks

        # Down-block additional residuals: the first has block_out_channels[0] channels at full
        # latent resolution; each down block then contributes one per layer, plus one after its
        # 2x downsampling for every block except the last.
        height, width = sample_height, sample_width
        input_info.append(
            (
                "down_block_additional_residuals_0",
                [batch_size, model_config.block_out_channels[0], height, width],
                "float32",
            )
        )
        name_idx = 1
        for idx in range(num_down_blocks):
            channels = model_config.block_out_channels[idx]
            shape = [batch_size, channels, height, width]
            for _ in range(layers_per_block[idx]):
                input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                name_idx += 1
            if idx != num_down_blocks - 1:
                height //= 2
                width //= 2
                shape = [batch_size, channels, height, width]
                input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                name_idx += 1

        # Mid-block additional residual: the mid block runs at the resolution left after the last
        # downsampling, i.e. the (height, width) tracked above. This no longer depends on how many
        # down blocks happen to be "CrossAttnDownBlock2D".
        input_info.append(
            (
                "mid_block_additional_residual",
                [batch_size, model_config.block_out_channels[-1], height, width],
                "float32",
            )
        )

    addition_embed_type = getattr(model_config, "addition_embed_type", None)
    if addition_embed_type == "text_time":
        rbln_config.in_features = model_config.projection_class_embeddings_input_dim
        text_hidden_size = rbln_config.text_model_hidden_size
        input_info.append(("text_embeds", [batch_size, text_hidden_size], "float32"))

        # The text_time projection input is text_embeds concatenated with one
        # `addition_time_embed_dim`-wide embedding per time id, so the number of time ids can be
        # derived from the config (e.g. 5 instead of 6 for aesthetic-score refiner UNets).
        # Fall back to the previously hard-coded 6 when it cannot be derived cleanly.
        num_time_ids = 6
        time_embed_dim = getattr(model_config, "addition_time_embed_dim", None)
        projection_dim = model_config.projection_class_embeddings_input_dim
        if (
            isinstance(time_embed_dim, int)
            and time_embed_dim > 0
            and isinstance(projection_dim, int)
            and isinstance(text_hidden_size, int)
        ):
            time_ids_width = projection_dim - text_hidden_size
            if time_ids_width > 0 and time_ids_width % time_embed_dim == 0:
                num_time_ids = time_ids_width // time_embed_dim
        input_info.append(("time_ids", [batch_size, num_time_ids], "float32"))
    elif addition_embed_type == "image":
        input_info.append(("image_embeds", [batch_size, rbln_config.image_model_hidden_size], "float32"))

    rbln_compile_config = RBLNCompileConfig(input_info=input_info)
    rbln_config.set_compile_cfgs([rbln_compile_config])

    return rbln_config
|
|
330
|
+
|
|
331
|
+
@property
def compiled_batch_size(self):
    """Batch size the model was compiled with.

    Taken from the leading dimension of the shape of the first input declared in the first
    compile config (``input_info`` entries are ``(name, shape, dtype)`` tuples).
    """
    first_compile_cfg = self.rbln_config.compile_cfgs[0]
    first_input = first_compile_cfg.input_info[0]
    first_input_shape = first_input[1]
    return first_input_shape[0]
|
|
334
|
+
|
|
335
|
+
def forward(
    self,
    sample: torch.Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    **kwargs,
) -> Union[UNet2DConditionOutput, Tuple]:
    """
    Forward pass for the RBLN-optimized UNet2DConditionModel.

    Only ``sample``, ``timestep``, ``encoder_hidden_states``, ``added_cond_kwargs``,
    ``down_block_additional_residuals``, ``mid_block_additional_residual`` and ``return_dict`` are
    forwarded to the compiled runtime. ``class_labels``, ``timestep_cond``, ``attention_mask``,
    ``cross_attention_kwargs``, ``down_intrablock_additional_residuals``, ``encoder_attention_mask``
    and ``**kwargs`` are accepted for signature compatibility with diffusers and are ignored.

    Args:
        sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
            Python scalars are converted to a float32 tensor.
        encoder_hidden_states (torch.Tensor): The encoder hidden states.
        added_cond_kwargs (Optional[Dict[str, torch.Tensor]]): A kwargs dictionary containing additional
            embeddings that if specified are added to the embeddings that are passed along to the UNet blocks.
            ``None`` is treated as an empty dictionary.
        down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
        mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
        return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

    Returns:
        (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)

    Raises:
        ValueError: If the runtime batch size is exactly half or double the compiled batch size
            (the typical symptom of classifier-free guidance doubling the batch).
    """
    sample_batch_size = sample.size()[0]
    compiled_batch_size = self.compiled_batch_size
    # Only the 2x / 0.5x mismatch gets this explanatory error; other mismatches surface from the runtime.
    if sample_batch_size != compiled_batch_size and (
        sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
    ):
        raise ValueError(
            f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
            "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size of UNet in Stable Diffusion. "
            "Adjust the batch size of UNet during compilation to match the runtime batch size.\n\n"
            "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/stable_diffusion.html#important-batch-size-configuration-for-guidance-scale"
        )

    added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs

    # The signature allows Python scalars, which have no `.float()`; normalize to a float32 tensor.
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.float()
    else:
        timestep = torch.tensor(timestep, dtype=torch.float32)
    sample = sample.contiguous()

    if down_block_additional_residuals is not None:
        # ControlNet path: each residual is passed as its own positional input after encoder_hidden_states.
        down_block_additional_residuals = [t.contiguous() for t in down_block_additional_residuals]
        return super().forward(
            sample,
            timestep,
            encoder_hidden_states,
            *down_block_additional_residuals,
            mid_block_additional_residual,
            **added_cond_kwargs,
            return_dict=return_dict,
        )

    if "image_embeds" in added_cond_kwargs:
        # NOTE(review): image-conditioned variant is called without encoder_hidden_states; presumably the
        # compiled input list omits it in that configuration — confirm against _update_rbln_config.
        return super().forward(
            sample,
            timestep,
            **added_cond_kwargs,
            return_dict=return_dict,
        )

    return super().forward(
        sample,
        timestep,
        encoder_hidden_states,
        **added_cond_kwargs,
        return_dict=return_dict,
    )
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from diffusers.models.unets.unet_spatio_temporal_condition import (
|
|
20
|
+
UNetSpatioTemporalConditionModel,
|
|
21
|
+
UNetSpatioTemporalConditionOutput,
|
|
22
|
+
)
|
|
23
|
+
from transformers import PretrainedConfig
|
|
24
|
+
|
|
25
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
26
|
+
from ....modeling import RBLNModel
|
|
27
|
+
from ....utils.logging import get_logger
|
|
28
|
+
from ...configurations import RBLNUNetSpatioTemporalConditionModelConfig
|
|
29
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, PreTrainedModel
|
|
34
|
+
|
|
35
|
+
logger = get_logger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class _UNet_STCM(torch.nn.Module):
|
|
39
|
+
def __init__(self, unet: "UNetSpatioTemporalConditionModel"):
|
|
40
|
+
super().__init__()
|
|
41
|
+
self.unet = unet
|
|
42
|
+
|
|
43
|
+
def forward(
|
|
44
|
+
self,
|
|
45
|
+
sample: torch.Tensor,
|
|
46
|
+
timestep: Union[torch.Tensor, float, int],
|
|
47
|
+
encoder_hidden_states: torch.Tensor,
|
|
48
|
+
added_time_ids: torch.Tensor,
|
|
49
|
+
) -> torch.Tensor:
|
|
50
|
+
unet_out = self.unet(
|
|
51
|
+
sample=sample,
|
|
52
|
+
timestep=timestep,
|
|
53
|
+
encoder_hidden_states=encoder_hidden_states,
|
|
54
|
+
added_time_ids=added_time_ids,
|
|
55
|
+
return_dict=False,
|
|
56
|
+
)
|
|
57
|
+
return unet_out
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class RBLNUNetSpatioTemporalConditionModel(RBLNModel):
    """RBLN-compiled counterpart of diffusers' ``UNetSpatioTemporalConditionModel``."""

    hf_library_name = "diffusers"
    auto_model_class = UNetSpatioTemporalConditionModel
    _rbln_config_class = RBLNUNetSpatioTemporalConditionModelConfig
    output_class = UNetSpatioTemporalConditionOutput
    output_key = "sample"

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)
        self.in_features = self.rbln_config.in_features
        if self.in_features is not None:
            # Build a minimal stand-in exposing `add_embedding.linear_1.in_features`, mirroring the
            # attribute path of the original diffusers module. Presumably this is for pipeline code
            # that reads that attribute from the UNet — confirm against the calling pipeline.

            @dataclass
            class LINEAR1:
                in_features: int

            @dataclass
            class ADDEMBEDDING:
                linear_1: LINEAR1

            self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNUNetSpatioTemporalConditionModelConfig
    ) -> torch.nn.Module:
        """Wrap the diffusers UNet in `_UNet_STCM` (fixed input order, tuple output) in eval mode."""
        return _UNet_STCM(model).eval()

    @classmethod
    def get_unet_sample_size(
        cls,
        pipe: RBLNDiffusionMixin,
        rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> Union[int, Tuple[int, int]]:
        """Compute the latent (height, width) from the image size and the pipeline's VAE scale factor.

        When `image_size` is None, the VAE config's `sample_size` is used as the image size
        (an int is treated as a square size).
        """
        scale_factor = pipe.vae_scale_factor

        if image_size is None:
            vae_sample_size = pipe.vae.config.sample_size
            if isinstance(vae_sample_size, int):
                vae_sample_size = (vae_sample_size, vae_sample_size)

            sample_size = (
                vae_sample_size[0] // scale_factor,
                vae_sample_size[1] // scale_factor,
            )
        else:
            sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
        return sample_size

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Fill `rbln_config.unet.sample_size` from the pipeline and return the (mutated) pipeline config."""
        rbln_config.unet.sample_size = cls.get_unet_sample_size(
            pipe, rbln_config.unet, image_size=rbln_config.image_size
        )
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
    ) -> RBLNUNetSpatioTemporalConditionModelConfig:
        """Fill missing config values from `model_config` and register the compile input spec."""
        if rbln_config.num_frames is None:
            rbln_config.num_frames = model_config.num_frames

        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size
        # The model config may store a square size as a plain int; indexing below needs (height, width).
        if isinstance(rbln_config.sample_size, int):
            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

        input_info = [
            (
                "sample",
                [
                    rbln_config.batch_size,
                    rbln_config.num_frames,
                    model_config.in_channels,
                    rbln_config.sample_size[0],
                    rbln_config.sample_size[1],
                ],
                "float32",
            ),
            ("timestep", [], "float32"),
            ("encoder_hidden_states", [rbln_config.batch_size, 1, model_config.cross_attention_dim], "float32"),
            ("added_time_ids", [rbln_config.batch_size, 3], "float32"),
        ]

        if hasattr(model_config, "addition_time_embed_dim"):
            rbln_config.in_features = model_config.projection_class_embeddings_input_dim

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])

        return rbln_config

    @property
    def compiled_batch_size(self):
        """Leading dimension of the first compiled input (`sample`), i.e. the compiled batch size."""
        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
        """
        Forward pass for the RBLN-optimized UNetSpatioTemporalConditionModel.

        Args:
            sample (torch.Tensor): The noisy input tensor with the following shape
                `(batch, num_frames, channel, height, width)`, matching the compiled `sample` input.
            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
                Python scalars are converted to a float32 tensor.
            encoder_hidden_states (torch.Tensor): The encoder hidden states.
            added_time_ids (torch.Tensor): A tensor containing additional sinusoidal embeddings and added to the time embeddings.
            return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] instead of a plain tuple.

        Returns:
            (Union[`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`], Tuple)

        Raises:
            ValueError: If the runtime batch size is exactly half or double the compiled batch size.
        """
        sample_batch_size = sample.size()[0]
        compiled_batch_size = self.compiled_batch_size
        if sample_batch_size != compiled_batch_size and (
            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
        ):
            raise ValueError(
                f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
            )

        # The signature allows Python scalars, which have no `.float()`; normalize to a float32 tensor.
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.float()
        else:
            timestep = torch.tensor(timestep, dtype=torch.float32)

        return super().forward(
            sample.contiguous(),
            timestep,
            encoder_hidden_states,
            added_time_ids,
            return_dict=return_dict,
        )
|