optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +12 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +16 -6
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +12 -8
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +242 -109
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +1 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +7 -1
- optimum/rbln/utils/runtime_utils.py +32 -0
- optimum/rbln/utils/submodule.py +3 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import rebel
|
|
18
|
+
import torch # noqa: I001
|
|
19
|
+
from diffusers import AutoencoderKLTemporalDecoder
|
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
|
21
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
|
22
|
+
from transformers import PretrainedConfig
|
|
23
|
+
|
|
24
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
25
|
+
from ....modeling import RBLNModel
|
|
26
|
+
from ....utils.logging import get_logger
|
|
27
|
+
from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
|
|
28
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
|
29
|
+
from .vae import (
|
|
30
|
+
DiagonalGaussianDistribution,
|
|
31
|
+
RBLNRuntimeVAEDecoder,
|
|
32
|
+
RBLNRuntimeVAEEncoder,
|
|
33
|
+
_VAEEncoder,
|
|
34
|
+
_VAETemporalDecoder,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
40
|
+
|
|
41
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
42
|
+
|
|
43
|
+
logger = get_logger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
|
|
47
|
+
auto_model_class = AutoencoderKLTemporalDecoder
|
|
48
|
+
hf_library_name = "diffusers"
|
|
49
|
+
_rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig
|
|
50
|
+
|
|
51
|
+
def __post_init__(self, **kwargs):
|
|
52
|
+
super().__post_init__(**kwargs)
|
|
53
|
+
|
|
54
|
+
if self.rbln_config.uses_encoder:
|
|
55
|
+
self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
|
|
56
|
+
self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
|
|
57
|
+
self.image_size = self.rbln_config.image_size
|
|
58
|
+
|
|
59
|
+
@classmethod
|
|
60
|
+
def _wrap_model_if_needed(
|
|
61
|
+
cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
|
|
62
|
+
) -> torch.nn.Module:
|
|
63
|
+
decoder_model = _VAETemporalDecoder(model)
|
|
64
|
+
decoder_model.num_frames = rbln_config.decode_chunk_size
|
|
65
|
+
decoder_model.eval()
|
|
66
|
+
|
|
67
|
+
if rbln_config.uses_encoder:
|
|
68
|
+
encoder_model = _VAEEncoder(model)
|
|
69
|
+
encoder_model.eval()
|
|
70
|
+
return encoder_model, decoder_model
|
|
71
|
+
else:
|
|
72
|
+
return decoder_model
|
|
73
|
+
|
|
74
|
+
@classmethod
|
|
75
|
+
def get_compiled_model(
|
|
76
|
+
cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
|
|
77
|
+
) -> Dict[str, rebel.RBLNCompiledModel]:
|
|
78
|
+
compiled_models = {}
|
|
79
|
+
if rbln_config.uses_encoder:
|
|
80
|
+
encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
81
|
+
enc_compiled_model = cls.compile(
|
|
82
|
+
encoder_model,
|
|
83
|
+
rbln_compile_config=rbln_config.compile_cfgs[0],
|
|
84
|
+
create_runtimes=rbln_config.create_runtimes,
|
|
85
|
+
device=rbln_config.device_map["encoder"],
|
|
86
|
+
)
|
|
87
|
+
compiled_models["encoder"] = enc_compiled_model
|
|
88
|
+
else:
|
|
89
|
+
decoder_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
90
|
+
dec_compiled_model = cls.compile(
|
|
91
|
+
decoder_model,
|
|
92
|
+
rbln_compile_config=rbln_config.compile_cfgs[-1],
|
|
93
|
+
create_runtimes=rbln_config.create_runtimes,
|
|
94
|
+
device=rbln_config.device_map["decoder"],
|
|
95
|
+
)
|
|
96
|
+
compiled_models["decoder"] = dec_compiled_model
|
|
97
|
+
|
|
98
|
+
return compiled_models
|
|
99
|
+
|
|
100
|
+
@classmethod
|
|
101
|
+
def get_vae_sample_size(
|
|
102
|
+
cls,
|
|
103
|
+
pipe: "RBLNDiffusionMixin",
|
|
104
|
+
rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
|
|
105
|
+
return_vae_scale_factor: bool = False,
|
|
106
|
+
) -> Tuple[int, int]:
|
|
107
|
+
sample_size = rbln_config.sample_size
|
|
108
|
+
if hasattr(pipe, "vae_scale_factor"):
|
|
109
|
+
vae_scale_factor = pipe.vae_scale_factor
|
|
110
|
+
else:
|
|
111
|
+
if hasattr(pipe.vae.config, "block_out_channels"):
|
|
112
|
+
vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
|
|
113
|
+
else:
|
|
114
|
+
vae_scale_factor = 8 # vae image processor default value 8 (int)
|
|
115
|
+
|
|
116
|
+
if sample_size is None:
|
|
117
|
+
sample_size = pipe.unet.config.sample_size
|
|
118
|
+
if isinstance(sample_size, int):
|
|
119
|
+
sample_size = (sample_size, sample_size)
|
|
120
|
+
sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
|
|
121
|
+
|
|
122
|
+
if return_vae_scale_factor:
|
|
123
|
+
return sample_size, vae_scale_factor
|
|
124
|
+
else:
|
|
125
|
+
return sample_size
|
|
126
|
+
|
|
127
|
+
@classmethod
|
|
128
|
+
def update_rbln_config_using_pipe(
|
|
129
|
+
cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
|
|
130
|
+
) -> "RBLNDiffusionMixinConfig":
|
|
131
|
+
rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
|
|
132
|
+
pipe, rbln_config.vae, return_vae_scale_factor=True
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
if rbln_config.vae.num_frames is None:
|
|
136
|
+
if hasattr(pipe.unet.config, "num_frames"):
|
|
137
|
+
rbln_config.vae.num_frames = pipe.unet.config.num_frames
|
|
138
|
+
else:
|
|
139
|
+
raise ValueError("num_frames should be specified in unet config.json")
|
|
140
|
+
|
|
141
|
+
if rbln_config.vae.decode_chunk_size is None:
|
|
142
|
+
rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames
|
|
143
|
+
|
|
144
|
+
def chunk_frame(num_frames, decode_chunk_size):
|
|
145
|
+
# get closest divisor to num_frames
|
|
146
|
+
divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
|
|
147
|
+
closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
|
|
148
|
+
if decode_chunk_size != closest:
|
|
149
|
+
logger.warning(
|
|
150
|
+
f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
|
|
151
|
+
)
|
|
152
|
+
return closest
|
|
153
|
+
|
|
154
|
+
decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
|
|
155
|
+
rbln_config.vae.decode_chunk_size = decode_chunk_size
|
|
156
|
+
return rbln_config
|
|
157
|
+
|
|
158
|
+
@classmethod
|
|
159
|
+
def _update_rbln_config(
|
|
160
|
+
cls,
|
|
161
|
+
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
|
162
|
+
model: "PreTrainedModel",
|
|
163
|
+
model_config: "PretrainedConfig",
|
|
164
|
+
rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
|
|
165
|
+
) -> RBLNAutoencoderKLTemporalDecoderConfig:
|
|
166
|
+
if rbln_config.sample_size is None:
|
|
167
|
+
rbln_config.sample_size = model_config.sample_size
|
|
168
|
+
|
|
169
|
+
if rbln_config.vae_scale_factor is None:
|
|
170
|
+
if hasattr(model_config, "block_out_channels"):
|
|
171
|
+
rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
|
|
172
|
+
else:
|
|
173
|
+
# vae image processor default value 8 (int)
|
|
174
|
+
rbln_config.vae_scale_factor = 8
|
|
175
|
+
|
|
176
|
+
compile_cfgs = []
|
|
177
|
+
if rbln_config.uses_encoder:
|
|
178
|
+
vae_enc_input_info = [
|
|
179
|
+
(
|
|
180
|
+
"x",
|
|
181
|
+
[
|
|
182
|
+
rbln_config.batch_size,
|
|
183
|
+
model_config.in_channels,
|
|
184
|
+
rbln_config.sample_size[0],
|
|
185
|
+
rbln_config.sample_size[1],
|
|
186
|
+
],
|
|
187
|
+
"float32",
|
|
188
|
+
)
|
|
189
|
+
]
|
|
190
|
+
compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
|
|
191
|
+
|
|
192
|
+
decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
|
|
193
|
+
vae_dec_input_info = [
|
|
194
|
+
(
|
|
195
|
+
"z",
|
|
196
|
+
[
|
|
197
|
+
decode_batch_size,
|
|
198
|
+
model_config.latent_channels,
|
|
199
|
+
rbln_config.latent_sample_size[0],
|
|
200
|
+
rbln_config.latent_sample_size[1],
|
|
201
|
+
],
|
|
202
|
+
"float32",
|
|
203
|
+
)
|
|
204
|
+
]
|
|
205
|
+
compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
|
|
206
|
+
|
|
207
|
+
rbln_config.set_compile_cfgs(compile_cfgs)
|
|
208
|
+
return rbln_config
|
|
209
|
+
|
|
210
|
+
@classmethod
|
|
211
|
+
def _create_runtimes(
|
|
212
|
+
cls,
|
|
213
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
|
214
|
+
rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
|
|
215
|
+
) -> List[rebel.Runtime]:
|
|
216
|
+
if len(compiled_models) == 1:
|
|
217
|
+
# decoder
|
|
218
|
+
expected_models = ["decoder"]
|
|
219
|
+
else:
|
|
220
|
+
expected_models = ["encoder", "decoder"]
|
|
221
|
+
|
|
222
|
+
if any(model_name not in rbln_config.device_map for model_name in expected_models):
|
|
223
|
+
cls._raise_missing_compiled_file_error(expected_models)
|
|
224
|
+
|
|
225
|
+
device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
|
|
226
|
+
return [
|
|
227
|
+
rebel.Runtime(
|
|
228
|
+
compiled_model,
|
|
229
|
+
tensor_type="pt",
|
|
230
|
+
device=device_val,
|
|
231
|
+
activate_profiler=rbln_config.activate_profiler,
|
|
232
|
+
timeout=rbln_config.timeout,
|
|
233
|
+
)
|
|
234
|
+
for compiled_model, device_val in zip(compiled_models, device_vals)
|
|
235
|
+
]
|
|
236
|
+
|
|
237
|
+
def encode(
|
|
238
|
+
self, x: torch.FloatTensor, return_dict: bool = True
|
|
239
|
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
|
240
|
+
"""
|
|
241
|
+
Encode an input image into a latent representation.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
x: The input image to encode.
|
|
245
|
+
return_dict:
|
|
246
|
+
Whether to return output as a dictionary. Defaults to True.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
The latent representation or AutoencoderKLOutput if return_dict=True
|
|
250
|
+
"""
|
|
251
|
+
posterior = self.encoder.encode(x)
|
|
252
|
+
|
|
253
|
+
if not return_dict:
|
|
254
|
+
return (posterior,)
|
|
255
|
+
|
|
256
|
+
return AutoencoderKLOutput(latent_dist=posterior)
|
|
257
|
+
|
|
258
|
+
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
|
|
259
|
+
"""
|
|
260
|
+
Decode a latent representation into a video.
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
z: The latent representation to decode.
|
|
264
|
+
return_dict:
|
|
265
|
+
Whether to return output as a dictionary. Defaults to True.
|
|
266
|
+
|
|
267
|
+
Returns:
|
|
268
|
+
The decoded video or DecoderOutput if return_dict=True
|
|
269
|
+
"""
|
|
270
|
+
decoded = self.decoder.decode(z)
|
|
271
|
+
|
|
272
|
+
if not return_dict:
|
|
273
|
+
return (decoded,)
|
|
274
|
+
|
|
275
|
+
return DecoderOutput(sample=decoded)
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, List
|
|
15
|
+
from typing import TYPE_CHECKING, List, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
|
|
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
if TYPE_CHECKING:
|
|
24
|
-
from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
|
|
24
|
+
from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
|
|
@@ -67,18 +67,37 @@ class _VAEDecoder(torch.nn.Module):
|
|
|
67
67
|
return vae_out
|
|
68
68
|
|
|
69
69
|
|
|
70
|
+
class _VAETemporalDecoder(torch.nn.Module):
|
|
71
|
+
def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
|
|
72
|
+
super().__init__()
|
|
73
|
+
self.vae = vae
|
|
74
|
+
self.num_frames = None
|
|
75
|
+
|
|
76
|
+
def forward(self, z):
|
|
77
|
+
vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
|
|
78
|
+
return vae_out
|
|
79
|
+
|
|
80
|
+
|
|
70
81
|
class _VAEEncoder(torch.nn.Module):
|
|
71
|
-
def __init__(self, vae: "AutoencoderKL"):
|
|
82
|
+
def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
|
|
72
83
|
super().__init__()
|
|
73
84
|
self.vae = vae
|
|
74
85
|
|
|
75
86
|
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
|
76
|
-
if self
|
|
77
|
-
|
|
87
|
+
if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
|
|
88
|
+
if self.use_tiling and (
|
|
89
|
+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
|
|
90
|
+
):
|
|
91
|
+
return self.tiled_encode(x, return_dict=return_dict)
|
|
92
|
+
|
|
93
|
+
if self.use_slicing and x.shape[0] > 1:
|
|
94
|
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
|
95
|
+
h = torch.cat(encoded_slices)
|
|
96
|
+
else:
|
|
97
|
+
h = self.encoder(x)
|
|
98
|
+
if self.quant_conv is not None:
|
|
99
|
+
h = self.quant_conv(h)
|
|
78
100
|
|
|
79
|
-
if self.use_slicing and x.shape[0] > 1:
|
|
80
|
-
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
|
81
|
-
h = torch.cat(encoded_slices)
|
|
82
101
|
else:
|
|
83
102
|
h = self.encoder(x)
|
|
84
103
|
if self.quant_conv is not None:
|
|
@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
|
|
|
118
118
|
)
|
|
119
119
|
|
|
120
120
|
@classmethod
|
|
121
|
-
def
|
|
121
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
122
122
|
use_encoder_hidden_states = False
|
|
123
123
|
for down_block in model.down_blocks:
|
|
124
124
|
if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
|
|
@@ -215,10 +215,25 @@ class RBLNControlNetModel(RBLNModel):
|
|
|
215
215
|
encoder_hidden_states: torch.Tensor,
|
|
216
216
|
controlnet_cond: torch.FloatTensor,
|
|
217
217
|
conditioning_scale: torch.Tensor = 1.0,
|
|
218
|
-
added_cond_kwargs: Dict[str, torch.Tensor] =
|
|
218
|
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
|
219
219
|
return_dict: bool = True,
|
|
220
220
|
**kwargs,
|
|
221
221
|
):
|
|
222
|
+
"""
|
|
223
|
+
Forward pass for the RBLN-optimized ControlNetModel.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
sample (torch.FloatTensor): The noisy input tensor.
|
|
227
|
+
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
|
|
228
|
+
encoder_hidden_states (torch.Tensor): The encoder hidden states.
|
|
229
|
+
controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
|
|
230
|
+
conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
|
|
231
|
+
added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
|
|
232
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
|
|
233
|
+
|
|
234
|
+
Returns:
|
|
235
|
+
(Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
|
|
236
|
+
"""
|
|
222
237
|
sample_batch_size = sample.size()[0]
|
|
223
238
|
compiled_batch_size = self.compiled_batch_size
|
|
224
239
|
if sample_batch_size != compiled_batch_size and (
|
|
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
|
|
|
77
77
|
self.clip_std = artifacts["clip_std"]
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
|
-
def
|
|
80
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
81
81
|
return _PriorTransformer(model).eval()
|
|
82
82
|
|
|
83
83
|
@classmethod
|
|
@@ -128,13 +128,27 @@ class RBLNPriorTransformer(RBLNModel):
|
|
|
128
128
|
|
|
129
129
|
def forward(
|
|
130
130
|
self,
|
|
131
|
-
hidden_states,
|
|
131
|
+
hidden_states: torch.Tensor,
|
|
132
132
|
timestep: Union[torch.Tensor, float, int],
|
|
133
133
|
proj_embedding: torch.Tensor,
|
|
134
134
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
135
135
|
attention_mask: Optional[torch.Tensor] = None,
|
|
136
136
|
return_dict: bool = True,
|
|
137
137
|
):
|
|
138
|
+
"""
|
|
139
|
+
Forward pass for the RBLN-optimized PriorTransformer.
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
hidden_states (torch.Tensor): The currently predicted image embeddings.
|
|
143
|
+
timestep (Union[torch.Tensor, float, int]): Current denoising step.
|
|
144
|
+
proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
|
|
145
|
+
encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
|
|
146
|
+
attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
|
|
147
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
|
|
148
|
+
|
|
149
|
+
Returns:
|
|
150
|
+
(Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
|
|
151
|
+
"""
|
|
138
152
|
# Convert timestep(long) and attention_mask(bool) to float
|
|
139
153
|
return super().forward(
|
|
140
154
|
hidden_states,
|
|
@@ -185,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
|
185
185
|
)
|
|
186
186
|
|
|
187
187
|
@classmethod
|
|
188
|
-
def
|
|
188
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
189
189
|
num_latent_frames = rbln_config.num_latent_frames
|
|
190
190
|
latent_height = rbln_config.latent_height
|
|
191
191
|
latent_width = rbln_config.latent_width
|
|
@@ -303,6 +303,21 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
|
303
303
|
padding_mask: Optional[torch.Tensor] = None,
|
|
304
304
|
return_dict: bool = True,
|
|
305
305
|
):
|
|
306
|
+
"""
|
|
307
|
+
Forward pass for the RBLN-optimized CosmosTransformer3DModel.
|
|
308
|
+
|
|
309
|
+
Args:
|
|
310
|
+
hidden_states (torch.Tensor): The currently predicted image embeddings.
|
|
311
|
+
timestep (torch.Tensor): Current denoising step.
|
|
312
|
+
encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
|
313
|
+
fps: (Optional[int]): Frames per second for the video being generated.
|
|
314
|
+
condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
|
|
315
|
+
padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
|
|
316
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
|
|
317
|
+
|
|
318
|
+
Returns:
|
|
319
|
+
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
|
|
320
|
+
"""
|
|
306
321
|
(
|
|
307
322
|
hidden_states,
|
|
308
323
|
temb,
|
|
@@ -77,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
|
|
|
77
77
|
super().__post_init__(**kwargs)
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
|
-
def
|
|
80
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
81
81
|
return SD3Transformer2DModelWrapper(model).eval()
|
|
82
82
|
|
|
83
83
|
@classmethod
|
|
@@ -161,6 +161,19 @@ class RBLNSD3Transformer2DModel(RBLNModel):
|
|
|
161
161
|
return_dict: bool = True,
|
|
162
162
|
**kwargs,
|
|
163
163
|
):
|
|
164
|
+
"""
|
|
165
|
+
Forward pass for the RBLN-optimized SD3Transformer2DModel.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
hidden_states (torch.FloatTensor): The currently predicted image embeddings.
|
|
169
|
+
encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
|
170
|
+
pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
|
|
171
|
+
timestep (torch.LongTensor): Current denoising step.
|
|
172
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
|
|
173
|
+
|
|
174
|
+
Returns:
|
|
175
|
+
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
|
|
176
|
+
"""
|
|
164
177
|
sample_batch_size = hidden_states.size()[0]
|
|
165
178
|
compiled_batch_size = self.compiled_batch_size
|
|
166
179
|
if sample_batch_size != compiled_batch_size and (
|
|
@@ -171,7 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
|
171
171
|
self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
|
|
172
172
|
|
|
173
173
|
@classmethod
|
|
174
|
-
def
|
|
174
|
+
def _wrap_model_if_needed(
|
|
175
175
|
cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
|
|
176
176
|
) -> torch.nn.Module:
|
|
177
177
|
if model.config.addition_embed_type == "text_time":
|
|
@@ -341,7 +341,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
|
341
341
|
timestep_cond: Optional[torch.Tensor] = None,
|
|
342
342
|
attention_mask: Optional[torch.Tensor] = None,
|
|
343
343
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
|
344
|
-
added_cond_kwargs: Dict[str, torch.Tensor] =
|
|
344
|
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
|
345
345
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
|
346
346
|
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
|
347
347
|
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
|
@@ -349,6 +349,22 @@ class RBLNUNet2DConditionModel(RBLNModel):
|
|
|
349
349
|
return_dict: bool = True,
|
|
350
350
|
**kwargs,
|
|
351
351
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
|
352
|
+
"""
|
|
353
|
+
Forward pass for the RBLN-optimized UNet2DConditionModel.
|
|
354
|
+
|
|
355
|
+
Args:
|
|
356
|
+
sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
|
357
|
+
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
|
|
358
|
+
encoder_hidden_states (torch.Tensor): The encoder hidden states.
|
|
359
|
+
added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
|
|
360
|
+
if specified are added to the embeddings that are passed along to the UNet blocks.
|
|
361
|
+
down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
|
362
|
+
mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
|
|
363
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
|
364
|
+
|
|
365
|
+
Returns:
|
|
366
|
+
(Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
|
|
367
|
+
"""
|
|
352
368
|
sample_batch_size = sample.size()[0]
|
|
353
369
|
compiled_batch_size = self.compiled_batch_size
|
|
354
370
|
if sample_batch_size != compiled_batch_size and (
|