optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import TYPE_CHECKING, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
import rebel
|
|
19
|
+
import torch
|
|
20
|
+
from diffusers import CosmosTransformer3DModel
|
|
21
|
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
22
|
+
from diffusers.models.transformers.transformer_cosmos import (
|
|
23
|
+
CosmosEmbedding,
|
|
24
|
+
CosmosLearnablePositionalEmbed,
|
|
25
|
+
CosmosPatchEmbed,
|
|
26
|
+
CosmosRotaryPosEmbed,
|
|
27
|
+
)
|
|
28
|
+
from torchvision import transforms
|
|
29
|
+
|
|
30
|
+
from ....configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNModelConfig
|
|
31
|
+
from ....modeling import RBLNModel
|
|
32
|
+
from ....utils.logging import get_logger
|
|
33
|
+
from ...configurations import RBLNCosmosTransformer3DModelConfig
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
38
|
+
|
|
39
|
+
from ...modeling_diffusers import RBLNCosmosTransformer3DModelConfig, RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
logger = get_logger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class CosmosTransformer3DModelWrapper(torch.nn.Module):
    """Traceable tail of a Cosmos transformer: blocks, output head and unpatchify.

    The wrapper runs every module in ``model.transformer_blocks``, then
    ``model.norm_out`` and ``model.proj_out``, and finally reshapes the token
    sequence back into a video-shaped tensor. It does not compute any
    embeddings itself; those are passed in as tensor arguments. The latent
    video size is fixed at construction time, and the unpatchify step derives
    its grid sizes from these integers.

    Args:
        model: Source transformer. Its ``transformer_blocks``, ``norm_out``,
            ``proj_out`` and ``config.patch_size`` (unpacked as
            ``(p_t, p_h, p_w)``) are used.
        num_latent_frames: Latent frame count before patching.
        latent_height: Latent height before patching.
        latent_width: Latent width before patching.
    """

    def __init__(
        self,
        model: "CosmosTransformer3DModel",
        num_latent_frames: int = 16,
        latent_height: int = 88,
        latent_width: int = 160,
    ) -> None:
        super().__init__()
        self.model = model
        self.num_latent_frames, self.latent_height, self.latent_width = (
            num_latent_frames,
            latent_height,
            latent_width,
        )
        self.p_t, self.p_h, self.p_w = model.config.patch_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        embedded_timestep: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb_0: torch.Tensor,
        image_rotary_emb_1: torch.Tensor,
        extra_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """Run the transformer blocks, the output head and the unpatchify reshape.

        ``image_rotary_emb_0`` and ``image_rotary_emb_1`` are re-packed into a
        single two-element list that every block receives. ``return_dict`` is
        not used; the result is always a 1-tuple.

        Returns:
            ``(video,)`` laid out as ``[B, C, T, H, W]``, where
            ``T = (num_latent_frames // p_t) * p_t`` and likewise for ``H``/``W``.
        """
        # Arguments that are identical for every block (one shared rotary list).
        shared_kwargs = {
            "encoder_hidden_states": encoder_hidden_states,
            "embedded_timestep": embedded_timestep,
            "temb": temb,
            "image_rotary_emb": [image_rotary_emb_0, image_rotary_emb_1],
            "extra_pos_emb": extra_pos_emb,
            "attention_mask": attention_mask,
        }
        for block in self.model.transformer_blocks:
            hidden_states = block(hidden_states=hidden_states, **shared_kwargs)

        # Patch-grid size along (time, height, width).
        grid = (
            self.num_latent_frames // self.p_t,
            self.latent_height // self.p_h,
            self.latent_width // self.p_w,
        )

        out = self.model.norm_out(hidden_states, embedded_timestep, temb)
        out = self.model.proj_out(out)
        # [B, T'H'W', p_h*p_w*p_t*C] -> [B, T', H', W', p_h, p_w, p_t, C]
        out = out.unflatten(2, (self.p_h, self.p_w, self.p_t, -1)).unflatten(1, grid)
        # -> [B, C, T', p_t, H', p_h, W', p_w]; the in-patch axes are packed in
        # (p_h, p_w, p_t) order, which is why the permutation looks unusual.
        out = out.permute(0, 7, 1, 6, 2, 4, 3, 5)
        # -> [B, C, T'*p_t, H'*p_h, W'*p_w]
        out = out.flatten(6, 7).flatten(4, 5).flatten(2, 3)
        return (out,)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
97
|
+
"""
|
|
98
|
+
RBLN implementation of CosmosTransformer3DModel for diffusion models like Cosmos.
|
|
99
|
+
|
|
100
|
+
The CosmosTransformer3DModel takes text and/or image embeddings from encoders (like CLIP) and
|
|
101
|
+
maps them to a shared latent space that guides the diffusion process to generate the desired image.
|
|
102
|
+
|
|
103
|
+
This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
|
|
104
|
+
the library implements for all its models.
|
|
105
|
+
"""
|
|
106
|
+
|
|
107
|
+
hf_library_name = "diffusers"
|
|
108
|
+
auto_model_class = CosmosTransformer3DModel
|
|
109
|
+
|
|
110
|
+
def __post_init__(self, **kwargs):
    """Rebuild the embedding modules that run in PyTorch on the host.

    After base-class initialisation, this loads ``torch_artifacts.pth`` from
    ``self.model_save_dir / self.subfolder``. It then re-creates ``rope``,
    ``learnable_pos_embed`` (or ``None``), ``patch_embed`` and ``time_embed``
    from ``self.config`` and restores each one from its saved state dict.
    ``compute_embedding`` uses these modules before the compiled runtime is
    called.
    """
    super().__post_init__(**kwargs)
    artifacts_path = self.model_save_dir / self.subfolder / "torch_artifacts.pth"
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects. This file is expected to hold only state dicts (and None), so
    # weights_only=True may suffice; review before loading model directories
    # from untrusted sources.
    artifacts = torch.load(artifacts_path, weights_only=False)

    config = self.config
    hidden_size = config.num_attention_heads * config.attention_head_dim
    # One extra input channel when compute_embedding concatenates the padding mask.
    patch_embed_in_channels = config.in_channels + 1 if config.concat_padding_mask else config.in_channels

    self.rope = CosmosRotaryPosEmbed(
        hidden_size=config.attention_head_dim,
        max_size=config.max_size,
        patch_size=config.patch_size,
        rope_scale=config.rope_scale,
    )
    self.rope.load_state_dict(artifacts["rope"])

    # The key may be missing entirely when the source model had no learnable
    # positional embedding; .get() treats "absent" and "None" the same way
    # instead of raising KeyError.
    learnable_state = artifacts.get("learnable_pos_embed")
    if learnable_state is None:
        self.learnable_pos_embed = None
    else:
        self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
            hidden_size=hidden_size,
            max_size=config.max_size,
            patch_size=config.patch_size,
        )
        self.learnable_pos_embed.load_state_dict(learnable_state)

    self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, config.patch_size, bias=False)
    self.patch_embed.load_state_dict(artifacts["patch_embed"])

    self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
    self.time_embed.load_state_dict(artifacts["time_embed"])
|
|
138
|
+
|
|
139
|
+
def compute_embedding(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    fps: Optional[int] = None,
    condition_mask: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
):
    """Compute, on the host, the embeddings that feed the compiled transformer.

    Steps:
      1. Append ``condition_mask`` (if given) and, when
         ``config.concat_padding_mask`` is set, the resized ``padding_mask``
         along dim 1. Reshape ``attention_mask`` to ``[B, 1, 1, S]``.
      2. Compute rotary embeddings (``self.rope``) and, when
         ``config.extra_pos_embed_type`` is truthy, learnable positional
         embeddings. Both are computed on the un-patchified input.
      3. Patchify with ``self.patch_embed`` and flatten dims 1..3 into one
         token axis.
      4. Compute timestep embeddings with ``self.time_embed``.

    Args:
        hidden_states: 5-D latent input; dim 0 is the batch and dim 2 the frames.
        timestep: Denoising timestep(s) passed to ``self.time_embed``.
        attention_mask: Optional mask reshaped to ``[B, 1, 1, S]``.
        fps: Forwarded to ``self.rope``.
        condition_mask: Optional tensor concatenated along dim 1.
        padding_mask: Required when ``config.concat_padding_mask`` is set.

    Returns:
        ``(hidden_states, temb, embedded_timestep, image_rotary_emb_0,
        image_rotary_emb_1, extra_pos_emb, attention_mask)``.

    Raises:
        ValueError: If ``config.concat_padding_mask`` is set but ``padding_mask`` is None.
    """
    # The 5-way unpack also rejects non-5-D input; only batch and frames are needed.
    batch_size, _, num_frames, _, _ = hidden_states.shape

    # 1. Concatenate masks along the channel axis & prepare attention mask
    if condition_mask is not None:
        hidden_states = torch.cat([hidden_states, condition_mask], dim=1)

    if self.config.concat_padding_mask:
        if padding_mask is None:
            raise ValueError("`padding_mask` must be provided when `config.concat_padding_mask` is enabled.")
        padding_mask = transforms.functional.resize(
            padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
        )
        hidden_states = torch.cat(
            [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
        )

    if attention_mask is not None:
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

    # 2. Generate positional embeddings
    image_rotary_emb = self.rope(hidden_states, fps=fps)
    extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None

    # 3. Patchify input
    hidden_states = self.patch_embed(hidden_states)
    hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

    # 4. Timestep embeddings
    temb, embedded_timestep = self.time_embed(hidden_states, timestep)

    return (
        hidden_states,
        temb,
        embedded_timestep,
        image_rotary_emb[0],
        image_rotary_emb[1],
        extra_pos_emb,
        attention_mask,
    )
|
|
186
|
+
|
|
187
|
+
@classmethod
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
    """Return ``model`` wrapped for tracing, in eval mode.

    The wrapper is a ``CosmosTransformer3DModelWrapper``. It takes the latent
    geometry (``num_latent_frames``, ``latent_height``, ``latent_width``) from
    ``rbln_config``.
    """
    wrapper = CosmosTransformer3DModelWrapper(
        model=model,
        num_latent_frames=rbln_config.num_latent_frames,
        latent_height=rbln_config.latent_height,
        latent_width=rbln_config.latent_width,
    )
    return wrapper.eval()
|
|
198
|
+
|
|
199
|
+
@classmethod
def update_rbln_config_using_pipe(
    cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
) -> RBLNCosmosTransformer3DModelConfig:
    """Fill in the transformer sub-config from pipeline attributes.

    The method mutates ``rbln_config.transformer`` in place:

    * ``num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1``
    * ``latent_height`` / ``latent_width`` = pixel size // ``pipe.vae_scale_factor_spatial``
    * ``max_seq_len`` from the text encoder's ``config.n_positions``
    * ``embedding_dim`` from the text encoder's token-embedding width

    ``submodule_name`` is not used. The pipeline-level ``rbln_config`` that was
    passed in is returned, not the transformer sub-config.
    """
    transformer_cfg = rbln_config.transformer
    transformer_cfg.num_latent_frames = (transformer_cfg.num_frames - 1) // pipe.vae_scale_factor_temporal + 1
    transformer_cfg.latent_height = transformer_cfg.height // pipe.vae_scale_factor_spatial
    transformer_cfg.latent_width = transformer_cfg.width // pipe.vae_scale_factor_spatial

    text_encoder = pipe.text_encoder
    transformer_cfg.max_seq_len = text_encoder.config.n_positions
    transformer_cfg.embedding_dim = text_encoder.encoder.embed_tokens.embedding_dim

    return rbln_config
|
|
212
|
+
|
|
213
|
+
@classmethod
def save_torch_artifacts(
    cls,
    model: "PreTrainedModel",
    save_dir_path: Path,
    subfolder: str,
    rbln_config: "RBLNModelConfig",
):
    """Save the state dicts of the host-side embedding modules.

    The dict is written to ``save_dir_path / subfolder / "torch_artifacts.pth"``
    with the keys ``"rope"``, ``"learnable_pos_embed"``, ``"patch_embed"`` and
    ``"time_embed"``. ``"learnable_pos_embed"`` is always present and holds
    ``None`` when the model has no learnable positional embedding, so a loader
    can index it directly. ``rbln_config`` is not used.
    """
    learnable_pos_embed = model.learnable_pos_embed
    save_dict = {
        "rope": model.rope.state_dict(),
        "learnable_pos_embed": None if learnable_pos_embed is None else learnable_pos_embed.state_dict(),
        "patch_embed": model.patch_embed.state_dict(),
        "time_embed": model.time_embed.state_dict(),
    }
    torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
|
|
228
|
+
|
|
229
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
    model: "PreTrainedModel",
    model_config: "PretrainedConfig",
    rbln_config: "RBLNCosmosTransformer3DModelConfig",
) -> RBLNCosmosTransformer3DModelConfig:
    """Declare the static input signature of the compiled transformer.

    The entries follow the positional inputs of
    ``CosmosTransformer3DModelWrapper.forward`` up to ``extra_pos_emb``. Shapes
    are derived from ``model_config.patch_size``, the latent geometry and
    batch size in ``rbln_config``, and the model's attention dimensions.
    ``preprocessors`` is not used.

    Returns:
        ``rbln_config`` with a single ``RBLNCompileConfig`` set.
    """
    p_t, p_h, p_w = model_config.patch_size
    batch_size = rbln_config.batch_size
    # Number of patch tokens seen by the blocks: T' * H' * W'.
    num_patches = (
        (rbln_config.num_latent_frames // p_t)
        * (rbln_config.latent_height // p_h)
        * (rbln_config.latent_width // p_w)
    )
    attention_head_dim = model_config.attention_head_dim
    hidden_size = model.config.num_attention_heads * model.config.attention_head_dim

    input_info = [
        ("hidden_states", [batch_size, num_patches, hidden_size], "float32"),
        (
            "encoder_hidden_states",
            [batch_size, rbln_config.max_seq_len, rbln_config.embedding_dim],
            "float32",
        ),
        ("embedded_timestep", [batch_size, hidden_size], "float32"),
        # temb and embedded_timestep come out of the same time_embed call in
        # compute_embedding, so temb is declared with the same leading dim as
        # embedded_timestep (assumes time_embed keeps the timestep's batch dim
        # on both outputs). A fixed leading dim of 1 only matches batch_size == 1.
        ("temb", [batch_size, hidden_size * 3], "float32"),
        ("image_rotary_emb_0", [num_patches, attention_head_dim], "float32"),
        ("image_rotary_emb_1", [num_patches, attention_head_dim], "float32"),
        # NOTE(review): declared unconditionally, although compute_embedding
        # yields None for extra_pos_emb when config.extra_pos_embed_type is
        # falsy; confirm such checkpoints are supported.
        ("extra_pos_emb", [batch_size, num_patches, hidden_size], "float32"),
    ]

    compile_config = RBLNCompileConfig(input_info=input_info)
    rbln_config.set_compile_cfgs([compile_config])
    return rbln_config
|
|
274
|
+
|
|
275
|
+
@classmethod
def _create_runtimes(
    cls,
    compiled_models: List[rebel.RBLNCompiledModel],
    rbln_config: RBLNModelConfig,
) -> List[rebel.Runtime]:
    """Create one ``rebel.Runtime`` (PyTorch tensor type) per compiled model.

    Every runtime is placed on ``rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME]``.
    If that key is missing, ``cls._raise_missing_compiled_file_error`` is called
    before any runtime is built.
    """
    if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
        cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])

    runtimes = []
    for compiled in compiled_models:
        runtime = rebel.Runtime(
            compiled,
            tensor_type="pt",
            device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
            activate_profiler=rbln_config.activate_profiler,
            timeout=rbln_config.timeout,
        )
        runtimes.append(runtime)
    return runtimes
|
|
294
|
+
|
|
295
|
+
def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    fps: Optional[int] = None,
    condition_mask: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """
    Forward pass for the RBLN-optimized CosmosTransformer3DModel.

    Embeddings are first prepared by ``self.compute_embedding``. The resulting tensors,
    together with ``encoder_hidden_states``, are then passed to ``self.model[0].forward``.

    Args:
        hidden_states (torch.Tensor): The currently predicted image embeddings.
        timestep (torch.Tensor): Current denoising step.
        encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        attention_mask (Optional[torch.Tensor]): Passed to ``compute_embedding``. The mask it returns is discarded and not sent to ``self.model[0]``.
        fps (Optional[int]): Frames per second for the video being generated.
        condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
        padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
        return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.

    Returns:
        (Union[`~diffusers.models.modeling_outputs.Transformer2DModelOutput`, Tuple]): The model output
        wrapped as ``Transformer2DModelOutput(sample=...)``, or the 1-tuple ``(output,)`` when
        ``return_dict`` is falsy.
    """
    embeddings = self.compute_embedding(hidden_states, timestep, attention_mask, fps, condition_mask, padding_mask)
    # compute_embedding must yield exactly seven items; the trailing attention mask is unused.
    (
        tokens,
        temb,
        embedded_timestep,
        rotary_emb_first,
        rotary_emb_second,
        extra_pos_emb,
        _unused_mask,
    ) = embeddings

    # Argument order mirrors the compiled input_info: hidden_states, encoder_hidden_states,
    # embedded_timestep, temb, image_rotary_emb_0, image_rotary_emb_1, extra_pos_emb.
    sample = self.model[0].forward(
        tokens,
        encoder_hidden_states,
        embedded_timestep,
        temb,
        rotary_emb_first,
        rotary_emb_second,
        extra_pos_emb,
    )

    if return_dict:
        return Transformer2DModelOutput(sample=sample)
    return (sample,)
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
|
19
|
+
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
|
20
|
+
from transformers import PretrainedConfig
|
|
21
|
+
|
|
22
|
+
from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
|
23
|
+
from ....modeling import RBLNModel
|
|
24
|
+
from ....utils.logging import get_logger
|
|
25
|
+
from ...configurations import RBLNSD3Transformer2DModelConfig
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
30
|
+
|
|
31
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
32
|
+
|
|
33
|
+
# Module-level logger; not referenced by the definitions in this module.
logger = get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SD3Transformer2DModelWrapper(torch.nn.Module):
    """
    Minimal ``torch.nn.Module`` adapter around an ``SD3Transformer2DModel``.

    Only ``hidden_states``, ``encoder_hidden_states``, ``pooled_projections`` and ``timestep``
    reach the wrapped model. The model is always invoked with ``return_dict=False``, so it
    yields a plain tuple regardless of what the caller asks for.
    """

    def __init__(self, model: "SD3Transformer2DModel") -> None:
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        # NOTE: the three parameters below are accepted for signature compatibility but are
        # ignored; ControlNet residuals are not supported by this wrapper.
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        """Call the wrapped transformer with the supported inputs and a tuple return."""
        forwarded = dict(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            pooled_projections=pooled_projections,
            timestep=timestep,
        )
        return self.model(**forwarded, return_dict=False)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class RBLNSD3Transformer2DModel(RBLNModel):
    """
    RBLN implementation of diffusers' ``SD3Transformer2DModel`` for Stable Diffusion 3 pipelines.

    The transformer is compiled with static input shapes for ``hidden_states``,
    ``encoder_hidden_states``, ``pooled_projections`` and ``timestep``. These shapes are
    derived from the RBLN config (batch size, latent sample size and prompt embedding length)
    and the Hugging Face model config.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    hf_library_name = "diffusers"
    auto_model_class = SD3Transformer2DModel
    _output_class = Transformer2DModelOutput

    def __post_init__(self, **kwargs):
        # No transformer-specific setup; defer entirely to RBLNModel.
        super().__post_init__(**kwargs)

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        """Wrap ``model`` so its forward returns a plain tuple, and switch the wrapper to eval mode."""
        wrapped = SD3Transformer2DModelWrapper(model)
        return wrapped.eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """
        Populate transformer fields of the pipeline config from the pipeline itself.

        If ``rbln_config.sample_size`` is unset, the transformer's ``sample_size`` is set as follows:
        - when ``rbln_config.image_size`` is set, each of its first two entries is floor-divided by
          ``pipe.vae_scale_factor``;
        - otherwise, it falls back to ``pipe.default_sample_size``.

        ``transformer.prompt_embed_length`` is always set to
        ``pipe.tokenizer_max_length + rbln_config.max_seq_len``.

        Returns:
            The same ``rbln_config`` object, updated in place.
        """
        if rbln_config.sample_size is None:
            image_size = rbln_config.image_size
            if image_size is None:
                rbln_config.transformer.sample_size = pipe.default_sample_size
            else:
                scale = pipe.vae_scale_factor
                rbln_config.transformer.sample_size = (image_size[0] // scale, image_size[1] // scale)

        rbln_config.transformer.prompt_embed_length = pipe.tokenizer_max_length + rbln_config.max_seq_len
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNSD3Transformer2DModelConfig,
    ) -> RBLNSD3Transformer2DModelConfig:
        """
        Fill in ``sample_size`` and register the static compile configuration.

        ``sample_size`` defaults to ``model_config.sample_size``; an integer value is expanded to
        a square pair. All four inputs are declared as ``float32``.

        Returns:
            The same ``rbln_config`` object, updated in place.
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size
        if isinstance(rbln_config.sample_size, int):
            rbln_config.sample_size = (rbln_config.sample_size,) * 2

        batch = rbln_config.batch_size
        spatial = rbln_config.sample_size
        input_info = [
            ("hidden_states", [batch, model_config.in_channels, spatial[0], spatial[1]], "float32"),
            (
                "encoder_hidden_states",
                [batch, rbln_config.prompt_embed_length, model_config.joint_attention_dim],
                "float32",
            ),
            ("pooled_projections", [batch, model_config.pooled_projection_dim], "float32"),
            ("timestep", [batch], "float32"),
        ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config

    @property
    def compiled_batch_size(self):
        """Leading dimension of the first compiled input (``hidden_states``), i.e. the compiled batch size."""
        hidden_states_spec = self.rbln_config.compile_cfgs[0].input_info[0]
        return hidden_states_spec[1][0]

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for the RBLN-optimized SD3Transformer2DModel.

        Args:
            hidden_states (torch.FloatTensor): The currently predicted image embeddings; its first
                dimension is treated as the runtime batch size.
            encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
            timestep (torch.LongTensor): Current denoising step.
            block_controlnet_hidden_states (List): Accepted for API compatibility; ignored.
            joint_attention_kwargs (Optional[Dict[str, Any]]): Accepted for API compatibility; ignored.
            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.
            **kwargs: Ignored.

        Returns:
            (Union[`~diffusers.models.modeling_outputs.Transformer2DModelOutput`, Tuple])

        Raises:
            ValueError: If the runtime batch size is exactly half or exactly double the compiled
                batch size. The message points to the guidance scale as a likely cause.
        """
        sample_batch_size = hidden_states.size()[0]
        compiled_batch_size = self.compiled_batch_size

        # Only factor-of-two mismatches get the dedicated guidance-scale error; other
        # mismatches are not intercepted here.
        compiled_is_double = sample_batch_size * 2 == compiled_batch_size
        runtime_is_double = sample_batch_size == compiled_batch_size * 2
        if (compiled_is_double or runtime_is_double) and sample_batch_size != compiled_batch_size:
            raise ValueError(
                f"Mismatch between transformer's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
                "Adjust the batch size of transformer during compilation.\n\n"
                "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/stable_diffusion_3.html#important-batch-size-configuration-for-guidance-scale"
            )

        return super().forward(
            hidden_states, encoder_hidden_states, pooled_projections, timestep, return_dict=return_dict
        )
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .unet_2d_condition import RBLNUNet2DConditionModel
|
|
16
|
+
from .unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModel
|