optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

@@ -0,0 +1,275 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+import rebel
+import torch  # noqa: I001
+from diffusers import AutoencoderKLTemporalDecoder
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from transformers import PretrainedConfig
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
+from .vae import (
+    DiagonalGaussianDistribution,
+    RBLNRuntimeVAEDecoder,
+    RBLNRuntimeVAEEncoder,
+    _VAEEncoder,
+    _VAETemporalDecoder,
+)
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+logger = get_logger(__name__)
+
+
+class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
+    auto_model_class = AutoencoderKLTemporalDecoder
+    hf_library_name = "diffusers"
+    _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+
+        if self.rbln_config.uses_encoder:
+            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
+        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+        self.image_size = self.rbln_config.image_size
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+    ) -> torch.nn.Module:
+        decoder_model = _VAETemporalDecoder(model)
+        decoder_model.num_frames = rbln_config.decode_chunk_size
+        decoder_model.eval()
+
+        if rbln_config.uses_encoder:
+            encoder_model = _VAEEncoder(model)
+            encoder_model.eval()
+            return encoder_model, decoder_model
+        else:
+            return decoder_model
+
+    @classmethod
+    def get_compiled_model(
+        cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+    ) -> Dict[str, rebel.RBLNCompiledModel]:
+        compiled_models = {}
+        if rbln_config.uses_encoder:
+            encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+            enc_compiled_model = cls.compile(
+                encoder_model,
+                rbln_compile_config=rbln_config.compile_cfgs[0],
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device_map["encoder"],
+            )
+            compiled_models["encoder"] = enc_compiled_model
+        else:
+            decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+        dec_compiled_model = cls.compile(
+            decoder_model,
+            rbln_compile_config=rbln_config.compile_cfgs[-1],
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device_map["decoder"],
+        )
+        compiled_models["decoder"] = dec_compiled_model
+
+        return compiled_models
+
+    @classmethod
+    def get_vae_sample_size(
+        cls,
+        pipe: "RBLNDiffusionMixin",
+        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+        return_vae_scale_factor: bool = False,
+    ) -> Tuple[int, int]:
+        sample_size = rbln_config.sample_size
+        if hasattr(pipe, "vae_scale_factor"):
+            vae_scale_factor = pipe.vae_scale_factor
+        else:
+            if hasattr(pipe.vae.config, "block_out_channels"):
+                vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+            else:
+                vae_scale_factor = 8  # vae image processor default value 8 (int)
+
+        if sample_size is None:
+            sample_size = pipe.unet.config.sample_size
+            if isinstance(sample_size, int):
+                sample_size = (sample_size, sample_size)
+            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
+
+        if return_vae_scale_factor:
+            return sample_size, vae_scale_factor
+        else:
+            return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+            pipe, rbln_config.vae, return_vae_scale_factor=True
+        )
+
+        if rbln_config.vae.num_frames is None:
+            if hasattr(pipe.unet.config, "num_frames"):
+                rbln_config.vae.num_frames = pipe.unet.config.num_frames
+            else:
+                raise ValueError("num_frames should be specified in unet config.json")
+
+        if rbln_config.vae.decode_chunk_size is None:
+            rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames
+
+        def chunk_frame(num_frames, decode_chunk_size):
+            # get closest divisor to num_frames
+            divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
+            closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
+            if decode_chunk_size != closest:
+                logger.warning(
+                    f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
+                )
+            return closest
+
+        decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
+        rbln_config.vae.decode_chunk_size = decode_chunk_size
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+    ) -> RBLNAutoencoderKLTemporalDecoderConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
+
+        if rbln_config.vae_scale_factor is None:
+            if hasattr(model_config, "block_out_channels"):
+                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+            else:
+                # vae image processor default value 8 (int)
+                rbln_config.vae_scale_factor = 8
+
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
+            vae_enc_input_info = [
+                (
+                    "x",
+                    [
+                        rbln_config.batch_size,
+                        model_config.in_channels,
+                        rbln_config.sample_size[0],
+                        rbln_config.sample_size[1],
+                    ],
+                    "float32",
+                )
+            ]
+            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+        decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
+        vae_dec_input_info = [
+            (
+                "z",
+                [
+                    decode_batch_size,
+                    model_config.latent_channels,
+                    rbln_config.latent_sample_size[0],
+                    rbln_config.latent_sample_size[1],
+                ],
+                "float32",
+            )
+        ]
+        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+        rbln_config.set_compile_cfgs(compile_cfgs)
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+    ) -> List[rebel.Runtime]:
+        if len(compiled_models) == 1:
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            expected_models = ["encoder", "decoder"]
+
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
+
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+        return [
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
+            )
+            for compiled_model, device_val in zip(compiled_models, device_vals)
+        ]
+
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode an input image into a latent representation.
+
+        Args:
+            x: The input image to encode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+
+        Returns:
+            The latent representation or AutoencoderKLOutput if return_dict=True
+        """
+        posterior = self.encoder.encode(x)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+        """
+        Decode a latent representation into a video.
+
+        Args:
+            z: The latent representation to decode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+
+        Returns:
+            The decoded video or DecoderOutput if return_dict=True
+        """
+        decoded = self.decoder.decode(z)
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)

optimum/rbln/diffusers/models/autoencoders/vae.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union
 
 import torch
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
 if TYPE_CHECKING:
-    from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
+    from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
 
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
@@ -67,18 +67,37 @@ class _VAEDecoder(torch.nn.Module):
         return vae_out
 
 
+class _VAETemporalDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
+        super().__init__()
+        self.vae = vae
+        self.num_frames = None
+
+    def forward(self, z):
+        vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
+        return vae_out
+
+
 class _VAEEncoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
+    def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
         super().__init__()
         self.vae = vae
 
     def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-        if self
-
+        if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
+            if self.use_tiling and (
+                x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
+            ):
+                return self.tiled_encode(x, return_dict=return_dict)
+
+            if self.use_slicing and x.shape[0] > 1:
+                encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+                h = torch.cat(encoded_slices)
+            else:
+                h = self.encoder(x)
+            if self.quant_conv is not None:
+                h = self.quant_conv(h)
 
-        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-            h = torch.cat(encoded_slices)
         else:
             h = self.encoder(x)
         if self.quant_conv is not None:

optimum/rbln/diffusers/models/autoencoders/vq_model.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, Any, List, Union
 
 import rebel
 import torch
@@ -165,17 +165,46 @@ class RBLNVQModel(RBLNModel):
                 tensor_type="pt",
                 device=device_val,
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             )
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]
 
-    def encode(
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+    ) -> Union[torch.FloatTensor, VQEncoderOutput]:
+        """
+        Encode an input image into a quantized latent representation.
+
+        Args:
+            x: The input image to encode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the encoder/quantizer.
+
+        Returns:
+            The quantized latent representation or a specific output object.
+        """
         posterior = self.encoder.encode(x)
         if not return_dict:
             return (posterior,)
         return VQEncoderOutput(latents=posterior)
 
-    def decode(
+    def decode(
+        self, h: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+    ) -> Union[torch.FloatTensor, DecoderOutput]:
+        """
+        Decode a quantized latent representation back into an image.
+
+        Args:
+            h: The quantized latent representation to decode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the decoder.
+
+        Returns:
+            The decoded image or a DecoderOutput object.
+        """
         dec, commit_loss = self.decoder.decode(h, **kwargs)
         if not return_dict:
             return (dec, commit_loss)

optimum/rbln/diffusers/models/controlnet.py

@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
         )
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -219,6 +219,21 @@ class RBLNControlNetModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ):
+        """
+        Forward pass for the RBLN-optimized ControlNetModel.
+
+        Args:
+            sample (torch.FloatTensor): The noisy input tensor.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
+            conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
+            added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
+
+        Returns:
+            (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
+        """
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (

optimum/rbln/diffusers/models/transformers/prior_transformer.py

@@ -59,7 +59,7 @@ class RBLNPriorTransformer(RBLNModel):
     """
     RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.
 
-    The
+    The PriorTransformer takes text and/or image embeddings from encoders (like CLIP) and
     maps them to a shared latent space that guides the diffusion process to generate the desired image.
 
     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
         self.clip_std = artifacts["clip_std"]
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return _PriorTransformer(model).eval()
 
     @classmethod
@@ -128,13 +128,27 @@ class RBLNPriorTransformer(RBLNModel):
 
     def forward(
         self,
-        hidden_states,
+        hidden_states: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         proj_embedding: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized PriorTransformer.
+
+        Args:
+            hidden_states (torch.Tensor): The currently predicted image embeddings.
+            timestep (Union[torch.Tensor, float, int]): Current denoising step.
+            proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
+            encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
+            attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
+        """
         # Convert timestep(long) and attention_mask(bool) to float
         return super().forward(
             hidden_states,

optimum/rbln/diffusers/models/transformers/transformer_cosmos.py

@@ -94,7 +94,15 @@ class CosmosTransformer3DModelWrapper(torch.nn.Module):
 
 
 class RBLNCosmosTransformer3DModel(RBLNModel):
-    """
+    """
+    RBLN implementation of CosmosTransformer3DModel for diffusion models like Cosmos.
+
+    The CosmosTransformer3DModel takes text and/or image embeddings from encoders (like CLIP) and
+    maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """
 
     hf_library_name = "diffusers"
     auto_model_class = CosmosTransformer3DModel
@@ -177,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
         )
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         num_latent_frames = rbln_config.num_latent_frames
         latent_height = rbln_config.latent_height
         latent_width = rbln_config.latent_width
@@ -279,7 +287,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
                 tensor_type="pt",
                 device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
                 activate_profiler=rbln_config.activate_profiler,
-                timeout=
+                timeout=rbln_config.timeout,
             )
             for compiled_model in compiled_models
         ]
@@ -295,6 +303,21 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
         padding_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized CosmosTransformer3DModel.
+
+        Args:
+            hidden_states (torch.Tensor): The currently predicted image embeddings.
+            timestep (torch.Tensor): Current denoising step.
+            encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            fps: (Optional[int]): Frames per second for the video being generated.
+            condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
+            padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+        """
         (
             hidden_states,
             temb,

optimum/rbln/diffusers/models/transformers/transformer_sd3.py

@@ -59,7 +59,15 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
 
 
 class RBLNSD3Transformer2DModel(RBLNModel):
-    """
+    """
+    RBLN implementation of SD3Transformer2DModel for diffusion models like Stable Diffusion 3.
+
+    The SD3Transformer2DModel takes text and/or image embeddings from encoders (like CLIP) and
+    maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """
 
     hf_library_name = "diffusers"
     auto_model_class = SD3Transformer2DModel
@@ -69,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
         super().__post_init__(**kwargs)
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return SD3Transformer2DModelWrapper(model).eval()
 
     @classmethod
@@ -153,6 +161,19 @@ class RBLNSD3Transformer2DModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ):
+        """
+        Forward pass for the RBLN-optimized SD3Transformer2DModel.
+
+        Args:
+            hidden_states (torch.FloatTensor): The currently predicted image embeddings.
+            encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
+            timestep (torch.LongTensor): Current denoising step.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+        """
         sample_batch_size = hidden_states.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (

optimum/rbln/diffusers/models/unets/unet_2d_condition.py

@@ -141,10 +141,13 @@ class _UNet_Kandinsky(torch.nn.Module):
 
 class RBLNUNet2DConditionModel(RBLNModel):
     """
-
+    RBLN implementation of UNet2DConditionModel for diffusion models.
 
-    This
-
+    This model is used to accelerate UNet2DCondition models from diffusers library on RBLN NPUs.
+    It is a key component in diffusion-based image generation models like Stable Diffusion.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
     """
 
     hf_library_name = "diffusers"
@@ -168,7 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
 
     @classmethod
-    def
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
     ) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
@@ -346,6 +349,22 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ) -> Union[UNet2DConditionOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized UNet2DConditionModel.
+
+        Args:
+            sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
+                if specified are added to the embeddings that are passed along to the UNet blocks.
+            down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
+            mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
+        """
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (