optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
from ....transformers import RBLNCLIPVisionModelWithProjectionConfig
|
|
19
|
+
from ..models import RBLNAutoencoderKLTemporalDecoderConfig, RBLNUNetSpatioTemporalConditionModelConfig
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RBLNStableVideoDiffusionPipelineConfig(RBLNModelConfig):
    """RBLN configuration for the Stable Video Diffusion pipeline.

    Groups one sub-configuration per pipeline component named in
    ``submodules`` (``image_encoder``, ``unet`` and ``vae``) and forwards the
    pipeline-level options to the components that consume them.
    """

    submodules = ["image_encoder", "unet", "vae"]
    # Passed to the VAE sub-config as ``uses_encoder``.
    _vae_uses_encoder = True

    def __init__(
        self,
        image_encoder: Optional[RBLNCLIPVisionModelWithProjectionConfig] = None,
        unet: Optional[RBLNUNetSpatioTemporalConditionModelConfig] = None,
        vae: Optional[RBLNAutoencoderKLTemporalDecoderConfig] = None,
        *,
        batch_size: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        decode_chunk_size: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        **kwargs: Any,
    ):
        """
        Args:
            image_encoder (Optional[RBLNCLIPVisionModelWithProjectionConfig]): Configuration for the image encoder component.
                Initialized as RBLNCLIPVisionModelWithProjectionConfig if not provided.
            unet (Optional[RBLNUNetSpatioTemporalConditionModelConfig]): Configuration for the UNet model component.
                Initialized as RBLNUNetSpatioTemporalConditionModelConfig if not provided.
            vae (Optional[RBLNAutoencoderKLTemporalDecoderConfig]): Configuration for the VAE model component.
                Initialized as RBLNAutoencoderKLTemporalDecoderConfig if not provided.
            batch_size (Optional[int]): Batch size for inference, forwarded to the image encoder and the VAE.
            height (Optional[int]): Height of the generated images. When omitted, the default of the
                original pipeline's ``__call__`` is used.
            width (Optional[int]): Width of the generated images. When omitted, the default of the
                original pipeline's ``__call__`` is used.
            num_frames (Optional[int]): The number of frames in the generated video.
            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
                Useful for managing memory usage during video generation.
            guidance_scale (Optional[float]): Scale for classifier-free guidance. When omitted, the
                ``max_guidance_scale`` default of the original pipeline's ``__call__`` is used.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Note:
            When the UNet batch size has not been specified explicitly, it is derived from the
            image encoder batch size and doubled if guidance_scale > 1.0 to accommodate
            classifier-free guidance.
        """
        super().__init__(**kwargs)

        # Fill in only the dimension(s) the caller left unset, so that an explicitly
        # given height (or width) is honored even when the other one is omitted.
        if height is None:
            height = self.get_default_values_for_original_cls("__call__", ["height"])["height"]
        if width is None:
            width = self.get_default_values_for_original_cls("__call__", ["width"])["width"]
        image_size = (height, width)

        self.image_encoder = self.initialize_submodule_config(
            image_encoder, cls_name="RBLNCLIPVisionModelWithProjectionConfig", batch_size=batch_size
        )
        self.unet = self.initialize_submodule_config(
            unet,
            cls_name="RBLNUNetSpatioTemporalConditionModelConfig",
            num_frames=num_frames,
        )
        self.vae = self.initialize_submodule_config(
            vae,
            cls_name="RBLNAutoencoderKLTemporalDecoderConfig",
            batch_size=batch_size,
            num_frames=num_frames,
            decode_chunk_size=decode_chunk_size,
            uses_encoder=self.__class__._vae_uses_encoder,
            sample_size=image_size,  # image size is equal to sample size in vae
        )

        # Get default guidance scale from original class to set UNet batch size
        if guidance_scale is None:
            guidance_scale = self.get_default_values_for_original_cls("__call__", ["max_guidance_scale"])[
                "max_guidance_scale"
            ]

        if not self.unet.batch_size_is_specified:
            do_classifier_free_guidance = guidance_scale > 1.0
            if do_classifier_free_guidance:
                self.unet.batch_size = self.image_encoder.batch_size * 2
            else:
                self.unet.batch_size = self.image_encoder.batch_size

    @property
    def batch_size(self):
        """Batch size of the pipeline, taken from the VAE sub-config."""
        return self.vae.batch_size

    @property
    def sample_size(self):
        """Sample size of the pipeline, taken from the UNet sub-config."""
        return self.unet.sample_size

    @property
    def image_size(self):
        """Image size of the pipeline, taken from the VAE sub-config's sample size."""
        return self.vae.sample_size
|
|
@@ -136,7 +136,7 @@ class RBLNDiffusionMixin:
|
|
|
136
136
|
*,
|
|
137
137
|
export: bool = None,
|
|
138
138
|
model_save_dir: Optional[PathLike] = None,
|
|
139
|
-
rbln_config: Dict[str, Any] =
|
|
139
|
+
rbln_config: Optional[Dict[str, Any]] = None,
|
|
140
140
|
lora_ids: Optional[Union[str, List[str]]] = None,
|
|
141
141
|
lora_weights_names: Optional[Union[str, List[str]]] = None,
|
|
142
142
|
lora_scales: Optional[Union[float, List[float]]] = None,
|
|
@@ -22,9 +22,11 @@ _import_structure = {
|
|
|
22
22
|
"RBLNAutoencoderKL",
|
|
23
23
|
"RBLNAutoencoderKLCosmos",
|
|
24
24
|
"RBLNVQModel",
|
|
25
|
+
"RBLNAutoencoderKLTemporalDecoder",
|
|
25
26
|
],
|
|
26
27
|
"unets": [
|
|
27
28
|
"RBLNUNet2DConditionModel",
|
|
29
|
+
"RBLNUNetSpatioTemporalConditionModel",
|
|
28
30
|
],
|
|
29
31
|
"controlnet": ["RBLNControlNetModel"],
|
|
30
32
|
"transformers": [
|
|
@@ -35,10 +37,22 @@ _import_structure = {
|
|
|
35
37
|
}
|
|
36
38
|
|
|
37
39
|
if TYPE_CHECKING:
|
|
38
|
-
from .autoencoders import
|
|
40
|
+
from .autoencoders import (
|
|
41
|
+
RBLNAutoencoderKL,
|
|
42
|
+
RBLNAutoencoderKLCosmos,
|
|
43
|
+
RBLNAutoencoderKLTemporalDecoder,
|
|
44
|
+
RBLNVQModel,
|
|
45
|
+
)
|
|
39
46
|
from .controlnet import RBLNControlNetModel
|
|
40
|
-
from .transformers import
|
|
41
|
-
|
|
47
|
+
from .transformers import (
|
|
48
|
+
RBLNCosmosTransformer3DModel,
|
|
49
|
+
RBLNPriorTransformer,
|
|
50
|
+
RBLNSD3Transformer2DModel,
|
|
51
|
+
)
|
|
52
|
+
from .unets import (
|
|
53
|
+
RBLNUNet2DConditionModel,
|
|
54
|
+
RBLNUNetSpatioTemporalConditionModel,
|
|
55
|
+
)
|
|
42
56
|
else:
|
|
43
57
|
import sys
|
|
44
58
|
|
|
@@ -68,7 +68,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
|
|
|
68
68
|
self.image_size = self.rbln_config.image_size
|
|
69
69
|
|
|
70
70
|
@classmethod
|
|
71
|
-
def
|
|
71
|
+
def _wrap_model_if_needed(
|
|
72
72
|
cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
|
|
73
73
|
) -> torch.nn.Module:
|
|
74
74
|
decoder_model = _VAECosmosDecoder(model)
|
|
@@ -98,7 +98,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
|
|
|
98
98
|
|
|
99
99
|
compiled_models = {}
|
|
100
100
|
if rbln_config.uses_encoder:
|
|
101
|
-
encoder_model, decoder_model = cls.
|
|
101
|
+
encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
102
102
|
enc_compiled_model = cls.compile(
|
|
103
103
|
encoder_model,
|
|
104
104
|
rbln_compile_config=rbln_config.compile_cfgs[0],
|
|
@@ -107,7 +107,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
|
|
|
107
107
|
)
|
|
108
108
|
compiled_models["encoder"] = enc_compiled_model
|
|
109
109
|
else:
|
|
110
|
-
decoder_model = cls.
|
|
110
|
+
decoder_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
111
111
|
dec_compiled_model = cls.compile(
|
|
112
112
|
decoder_model,
|
|
113
113
|
rbln_compile_config=rbln_config.compile_cfgs[-1],
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import rebel
|
|
18
|
+
import torch # noqa: I001
|
|
19
|
+
from diffusers import AutoencoderKLTemporalDecoder
|
|
20
|
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
|
21
|
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
|
22
|
+
from transformers import PretrainedConfig
|
|
23
|
+
|
|
24
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
25
|
+
from ....modeling import RBLNModel
|
|
26
|
+
from ....utils.logging import get_logger
|
|
27
|
+
from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
|
|
28
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
|
29
|
+
from .vae import (
|
|
30
|
+
DiagonalGaussianDistribution,
|
|
31
|
+
RBLNRuntimeVAEDecoder,
|
|
32
|
+
RBLNRuntimeVAEEncoder,
|
|
33
|
+
_VAEEncoder,
|
|
34
|
+
_VAETemporalDecoder,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
40
|
+
|
|
41
|
+
from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
42
|
+
|
|
43
|
+
logger = get_logger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
    """RBLN wrapper around diffusers' ``AutoencoderKLTemporalDecoder``.

    The decoder is always compiled; the encoder is compiled in addition when
    ``rbln_config.uses_encoder`` is set. At runtime the compiled graphs are
    exposed through ``encode`` and ``decode``.
    """

    auto_model_class = AutoencoderKLTemporalDecoder
    hf_library_name = "diffusers"
    _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig

    def __post_init__(self, **kwargs):
        """Bind the loaded runtimes to encoder/decoder helpers."""
        super().__post_init__(**kwargs)

        # The encoder runtime (first entry) only exists when it was compiled;
        # the decoder runtime is always the last entry.
        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
        self.image_size = self.rbln_config.image_size

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
    ) -> Union[torch.nn.Module, Tuple[torch.nn.Module, torch.nn.Module]]:
        """Wrap ``model`` into the module(s) that get compiled.

        Returns:
            ``(encoder_model, decoder_model)`` when ``rbln_config.uses_encoder``
            is set, otherwise only ``decoder_model``. All returned modules are
            in eval mode.
        """
        decoder_model = _VAETemporalDecoder(model)
        # The decoder graph processes ``decode_chunk_size`` frames per call.
        decoder_model.num_frames = rbln_config.decode_chunk_size
        decoder_model.eval()

        if rbln_config.uses_encoder:
            encoder_model = _VAEEncoder(model)
            encoder_model.eval()
            return encoder_model, decoder_model
        else:
            return decoder_model

    @classmethod
    def get_compiled_model(
        cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
    ) -> Dict[str, rebel.RBLNCompiledModel]:
        """Compile the decoder (and the encoder, if used).

        Returns:
            A dict keyed by ``"encoder"`` (only when ``uses_encoder`` is set)
            and ``"decoder"``, mapping to the compiled models.
        """
        compiled_models = {}
        wrapped = cls._wrap_model_if_needed(model, rbln_config)
        if rbln_config.uses_encoder:
            encoder_model, decoder_model = wrapped
            # The encoder compile config, when present, is the first entry.
            enc_compiled_model = cls.compile(
                encoder_model,
                rbln_compile_config=rbln_config.compile_cfgs[0],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map["encoder"],
            )
            compiled_models["encoder"] = enc_compiled_model
        else:
            decoder_model = wrapped

        # The decoder compile config is always the last entry.
        dec_compiled_model = cls.compile(
            decoder_model,
            rbln_compile_config=rbln_config.compile_cfgs[-1],
            create_runtimes=rbln_config.create_runtimes,
            device=rbln_config.device_map["decoder"],
        )
        compiled_models["decoder"] = dec_compiled_model

        return compiled_models

    @classmethod
    def get_vae_sample_size(
        cls,
        pipe: "RBLNDiffusionMixin",
        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
        return_vae_scale_factor: bool = False,
    ) -> Union[Tuple[int, int], Tuple[Tuple[int, int], int]]:
        """Resolve the VAE sample size, optionally together with the VAE scale factor.

        The scale factor is read from ``pipe.vae_scale_factor`` when available,
        otherwise derived from ``pipe.vae.config.block_out_channels``, and falls
        back to 8. When ``rbln_config.sample_size`` is unset, the UNet's
        ``sample_size`` multiplied by the scale factor is used.

        Returns:
            ``sample_size``, or ``(sample_size, vae_scale_factor)`` when
            ``return_vae_scale_factor`` is True.
        """
        sample_size = rbln_config.sample_size
        if hasattr(pipe, "vae_scale_factor"):
            vae_scale_factor = pipe.vae_scale_factor
        else:
            if hasattr(pipe.vae.config, "block_out_channels"):
                vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
            else:
                vae_scale_factor = 8  # vae image processor default value 8 (int)

        if sample_size is None:
            sample_size = pipe.unet.config.sample_size
            if isinstance(sample_size, int):
                sample_size = (sample_size, sample_size)
            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)

        if return_vae_scale_factor:
            return sample_size, vae_scale_factor
        else:
            return sample_size

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Complete ``rbln_config.vae`` with values taken from the pipeline.

        Sets the VAE sample size and scale factor, fills ``num_frames`` from the
        UNet config when unset, and snaps ``decode_chunk_size`` to a proper
        divisor of ``num_frames``.

        Raises:
            ValueError: If ``num_frames`` is unset and missing from the UNet
                config, or if ``num_frames`` is not a positive integer.
        """
        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
            pipe, rbln_config.vae, return_vae_scale_factor=True
        )

        if rbln_config.vae.num_frames is None:
            if hasattr(pipe.unet.config, "num_frames"):
                rbln_config.vae.num_frames = pipe.unet.config.num_frames
            else:
                raise ValueError("num_frames should be specified in unet config.json")

        if rbln_config.vae.decode_chunk_size is None:
            rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames

        def chunk_frame(num_frames, decode_chunk_size):
            # get closest divisor to num_frames
            if num_frames < 1:
                raise ValueError(f"num_frames must be a positive integer, got {num_frames}.")
            # Only proper divisors are candidates (num_frames itself is excluded).
            # A single frame has no proper divisor, so fall back to a chunk of 1
            # instead of letting min() fail on an empty list.
            divisors = [i for i in range(1, num_frames) if num_frames % i == 0] or [1]
            closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
            if decode_chunk_size != closest:
                logger.warning(
                    f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
                )
            return closest

        decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
        rbln_config.vae.decode_chunk_size = decode_chunk_size
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
    ) -> RBLNAutoencoderKLTemporalDecoderConfig:
        """Fill in missing config values and register the compile configs.

        Defaults ``sample_size`` and ``vae_scale_factor`` from ``model_config``
        when unset, then builds the encoder (optional) and decoder input
        descriptions and stores them via ``set_compile_cfgs``.
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        if rbln_config.vae_scale_factor is None:
            if hasattr(model_config, "block_out_channels"):
                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
            else:
                # vae image processor default value 8 (int)
                rbln_config.vae_scale_factor = 8

        compile_cfgs = []
        if rbln_config.uses_encoder:
            vae_enc_input_info = [
                (
                    "x",
                    [
                        rbln_config.batch_size,
                        model_config.in_channels,
                        rbln_config.sample_size[0],
                        rbln_config.sample_size[1],
                    ],
                    "float32",
                )
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        # The decoder's leading dimension is batch_size * decode_chunk_size.
        decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
        vae_dec_input_info = [
            (
                "z",
                [
                    decode_batch_size,
                    model_config.latent_channels,
                    rbln_config.latent_sample_size[0],
                    rbln_config.latent_sample_size[1],
                ],
                "float32",
            )
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
    ) -> List[rebel.Runtime]:
        """Create one runtime per compiled model on its mapped device.

        A single compiled model is treated as the decoder; otherwise the models
        are taken to be ``["encoder", "decoder"]`` in that order.
        """
        if len(compiled_models) == 1:
            # decoder
            expected_models = ["decoder"]
        else:
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode an input image into a latent representation.

        Uses ``self.encoder``, which is only created when the model was built
        with ``uses_encoder`` enabled.

        Args:
            x: The input image to encode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.

        Returns:
            AutoencoderKLOutput if return_dict=True, otherwise a 1-tuple holding the posterior.
        """
        posterior = self.encoder.encode(x)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        """
        Decode a latent representation into a video.

        Args:
            z: The latent representation to decode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.

        Returns:
            DecoderOutput if return_dict=True, otherwise a 1-tuple holding the decoded video.
        """
        decoded = self.decoder.decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, List
|
|
15
|
+
from typing import TYPE_CHECKING, List, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
|
|
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
if TYPE_CHECKING:
|
|
24
|
-
from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
|
|
24
|
+
from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
|
|
@@ -67,18 +67,37 @@ class _VAEDecoder(torch.nn.Module):
|
|
|
67
67
|
return vae_out
|
|
68
68
|
|
|
69
69
|
|
|
70
|
+
class _VAETemporalDecoder(torch.nn.Module):
|
|
71
|
+
def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
|
|
72
|
+
super().__init__()
|
|
73
|
+
self.vae = vae
|
|
74
|
+
self.num_frames = None
|
|
75
|
+
|
|
76
|
+
def forward(self, z):
|
|
77
|
+
vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
|
|
78
|
+
return vae_out
|
|
79
|
+
|
|
80
|
+
|
|
70
81
|
class _VAEEncoder(torch.nn.Module):
|
|
71
|
-
def __init__(self, vae: "AutoencoderKL"):
|
|
82
|
+
def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
|
|
72
83
|
super().__init__()
|
|
73
84
|
self.vae = vae
|
|
74
85
|
|
|
75
86
|
def encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
|
76
|
-
if self
|
|
77
|
-
|
|
87
|
+
if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
|
|
88
|
+
if self.use_tiling and (
|
|
89
|
+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
|
|
90
|
+
):
|
|
91
|
+
return self.tiled_encode(x, return_dict=return_dict)
|
|
92
|
+
|
|
93
|
+
if self.use_slicing and x.shape[0] > 1:
|
|
94
|
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
|
95
|
+
h = torch.cat(encoded_slices)
|
|
96
|
+
else:
|
|
97
|
+
h = self.encoder(x)
|
|
98
|
+
if self.quant_conv is not None:
|
|
99
|
+
h = self.quant_conv(h)
|
|
78
100
|
|
|
79
|
-
if self.use_slicing and x.shape[0] > 1:
|
|
80
|
-
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
|
81
|
-
h = torch.cat(encoded_slices)
|
|
82
101
|
else:
|
|
83
102
|
h = self.encoder(x)
|
|
84
103
|
if self.quant_conv is not None:
|
|
@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
|
|
|
118
118
|
)
|
|
119
119
|
|
|
120
120
|
@classmethod
|
|
121
|
-
def
|
|
121
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
122
122
|
use_encoder_hidden_states = False
|
|
123
123
|
for down_block in model.down_blocks:
|
|
124
124
|
if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
|
|
@@ -215,10 +215,25 @@ class RBLNControlNetModel(RBLNModel):
|
|
|
215
215
|
encoder_hidden_states: torch.Tensor,
|
|
216
216
|
controlnet_cond: torch.FloatTensor,
|
|
217
217
|
conditioning_scale: torch.Tensor = 1.0,
|
|
218
|
-
added_cond_kwargs: Dict[str, torch.Tensor] =
|
|
218
|
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
|
219
219
|
return_dict: bool = True,
|
|
220
220
|
**kwargs,
|
|
221
221
|
):
|
|
222
|
+
"""
|
|
223
|
+
Forward pass for the RBLN-optimized ControlNetModel.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
sample (torch.FloatTensor): The noisy input tensor.
|
|
227
|
+
timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
|
|
228
|
+
encoder_hidden_states (torch.Tensor): The encoder hidden states.
|
|
229
|
+
controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
|
|
230
|
+
conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
|
|
231
|
+
added_cond_kwargs (Optional[Dict[str, torch.Tensor]]): Additional conditions for the Stable Diffusion XL UNet.
|
|
232
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
|
|
233
|
+
|
|
234
|
+
Returns:
|
|
235
|
+
(Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`, Tuple])
|
|
236
|
+
"""
|
|
222
237
|
sample_batch_size = sample.size()[0]
|
|
223
238
|
compiled_batch_size = self.compiled_batch_size
|
|
224
239
|
if sample_batch_size != compiled_batch_size and (
|
|
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
|
|
|
77
77
|
self.clip_std = artifacts["clip_std"]
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
|
-
def
|
|
80
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
81
81
|
return _PriorTransformer(model).eval()
|
|
82
82
|
|
|
83
83
|
@classmethod
|
|
@@ -128,13 +128,27 @@ class RBLNPriorTransformer(RBLNModel):
|
|
|
128
128
|
|
|
129
129
|
def forward(
|
|
130
130
|
self,
|
|
131
|
-
hidden_states,
|
|
131
|
+
hidden_states: torch.Tensor,
|
|
132
132
|
timestep: Union[torch.Tensor, float, int],
|
|
133
133
|
proj_embedding: torch.Tensor,
|
|
134
134
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
|
135
135
|
attention_mask: Optional[torch.Tensor] = None,
|
|
136
136
|
return_dict: bool = True,
|
|
137
137
|
):
|
|
138
|
+
"""
|
|
139
|
+
Forward pass for the RBLN-optimized PriorTransformer.
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
hidden_states (torch.Tensor): The currently predicted image embeddings.
|
|
143
|
+
timestep (Union[torch.Tensor, float, int]): Current denoising step.
|
|
144
|
+
proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
|
|
145
|
+
encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
|
|
146
|
+
attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
|
|
147
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
|
|
148
|
+
|
|
149
|
+
Returns:
|
|
150
|
+
(Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
|
|
151
|
+
"""
|
|
138
152
|
# Convert timestep(long) and attention_mask(bool) to float
|
|
139
153
|
return super().forward(
|
|
140
154
|
hidden_states,
|
|
@@ -185,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
|
185
185
|
)
|
|
186
186
|
|
|
187
187
|
@classmethod
|
|
188
|
-
def
|
|
188
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
189
189
|
num_latent_frames = rbln_config.num_latent_frames
|
|
190
190
|
latent_height = rbln_config.latent_height
|
|
191
191
|
latent_width = rbln_config.latent_width
|
|
@@ -303,6 +303,21 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
|
|
|
303
303
|
padding_mask: Optional[torch.Tensor] = None,
|
|
304
304
|
return_dict: bool = True,
|
|
305
305
|
):
|
|
306
|
+
"""
|
|
307
|
+
Forward pass for the RBLN-optimized CosmosTransformer3DModel.
|
|
308
|
+
|
|
309
|
+
Args:
|
|
310
|
+
hidden_states (torch.Tensor): The currently predicted image embeddings.
|
|
311
|
+
timestep (torch.Tensor): Current denoising step.
|
|
312
|
+
encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
|
313
|
+
fps (Optional[int]): Frames per second for the video being generated.
|
|
314
|
+
condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
|
|
315
|
+
padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
|
|
316
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
|
|
317
|
+
|
|
318
|
+
Returns:
|
|
319
|
+
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
|
|
320
|
+
"""
|
|
306
321
|
(
|
|
307
322
|
hidden_states,
|
|
308
323
|
temb,
|
|
@@ -77,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
|
|
|
77
77
|
super().__post_init__(**kwargs)
|
|
78
78
|
|
|
79
79
|
@classmethod
|
|
80
|
-
def
|
|
80
|
+
def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
|
|
81
81
|
return SD3Transformer2DModelWrapper(model).eval()
|
|
82
82
|
|
|
83
83
|
@classmethod
|
|
@@ -161,6 +161,19 @@ class RBLNSD3Transformer2DModel(RBLNModel):
|
|
|
161
161
|
return_dict: bool = True,
|
|
162
162
|
**kwargs,
|
|
163
163
|
):
|
|
164
|
+
"""
|
|
165
|
+
Forward pass for the RBLN-optimized SD3Transformer2DModel.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
hidden_states (torch.FloatTensor): The currently predicted image embeddings.
|
|
169
|
+
encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
|
170
|
+
pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
|
|
171
|
+
timestep (torch.LongTensor): Current denoising step.
|
|
172
|
+
return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
|
|
173
|
+
|
|
174
|
+
Returns:
|
|
175
|
+
(Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
|
|
176
|
+
"""
|
|
164
177
|
sample_batch_size = hidden_states.size()[0]
|
|
165
178
|
compiled_batch_size = self.compiled_batch_size
|
|
166
179
|
if sample_batch_size != compiled_batch_size and (
|