optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/unets/unet_2d_condition.py

@@ -171,7 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
             self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
 
     @classmethod
-    def
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
     ) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
@@ -341,7 +341,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        added_cond_kwargs: Dict[str, torch.Tensor] =
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
@@ -349,6 +349,22 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ) -> Union[UNet2DConditionOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized UNet2DConditionModel.
+
+        Args:
+            sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
+                if specified are added to the embeddings that are passed along to the UNet blocks.
+            down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
+            mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
+        """
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py (new file)

@@ -0,0 +1,201 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import torch
+from diffusers.models.unets.unet_spatio_temporal_condition import (
+    UNetSpatioTemporalConditionModel,
+    UNetSpatioTemporalConditionOutput,
+)
+from transformers import PretrainedConfig
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ...configurations import RBLNUNetSpatioTemporalConditionModelConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, PreTrainedModel
+
+logger = get_logger(__name__)
+
+
+class _UNet_STCM(torch.nn.Module):
+    def __init__(self, unet: "UNetSpatioTemporalConditionModel"):
+        super().__init__()
+        self.unet = unet
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        added_time_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        unet_out = self.unet(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            added_time_ids=added_time_ids,
+            return_dict=False,
+        )
+        return unet_out
+
+
+class RBLNUNetSpatioTemporalConditionModel(RBLNModel):
+    hf_library_name = "diffusers"
+    auto_model_class = UNetSpatioTemporalConditionModel
+    _rbln_config_class = RBLNUNetSpatioTemporalConditionModelConfig
+    output_class = UNetSpatioTemporalConditionOutput
+    output_key = "sample"
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.in_features
+        if self.in_features is not None:
+
+            @dataclass
+            class LINEAR1:
+                in_features: int
+
+            @dataclass
+            class ADDEMBEDDING:
+                linear_1: LINEAR1
+
+            self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNUNetSpatioTemporalConditionModelConfig
+    ) -> torch.nn.Module:
+        return _UNet_STCM(model).eval()
+
+    @classmethod
+    def get_unet_sample_size(
+        cls,
+        pipe: RBLNDiffusionMixin,
+        rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
+        image_size: Optional[Tuple[int, int]] = None,
+    ) -> Union[int, Tuple[int, int]]:
+        scale_factor = pipe.vae_scale_factor
+
+        if image_size is None:
+            vae_sample_size = pipe.vae.config.sample_size
+            if isinstance(vae_sample_size, int):
+                vae_sample_size = (vae_sample_size, vae_sample_size)
+
+            sample_size = (
+                vae_sample_size[0] // scale_factor,
+                vae_sample_size[1] // scale_factor,
+            )
+        else:
+            sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
+        return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> Dict[str, Any]:
+        rbln_config.unet.sample_size = cls.get_unet_sample_size(
+            pipe, rbln_config.unet, image_size=rbln_config.image_size
+        )
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
+    ) -> RBLNUNetSpatioTemporalConditionModelConfig:
+        if rbln_config.num_frames is None:
+            rbln_config.num_frames = model_config.num_frames
+
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
+
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_config.batch_size,
+                    rbln_config.num_frames,
+                    model_config.in_channels,
+                    rbln_config.sample_size[0],
+                    rbln_config.sample_size[1],
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+            ("encoder_hidden_states", [rbln_config.batch_size, 1, model_config.cross_attention_dim], "float32"),
+            ("added_time_ids", [rbln_config.batch_size, 3], "float32"),
+        ]
+
+        if hasattr(model_config, "addition_time_embed_dim"):
+            rbln_config.in_features = model_config.projection_class_embeddings_input_dim
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+
+        return rbln_config
+
+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        added_time_ids: torch.Tensor,
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized UNetSpatioTemporalConditionModel.
+
+        Args:
+            sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            added_time_ids (torch.Tensor): A tensor containing additional sinusoidal embeddings and added to the time embeddings.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`], Tuple)
+        """
+        sample_batch_size = sample.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+        return super().forward(
+            sample.contiguous(),
+            timestep.float(),
+            encoder_hidden_states,
+            added_time_ids,
+            return_dict=return_dict,
+        )
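The batch-size guard in the new `forward` above exists because classifier-free guidance doubles the batch the UNet actually sees at runtime, while the compiled RBLN graph has a fixed batch dimension. A minimal, library-free sketch of where that factor of two comes from (tensor shapes and names are illustrative only, not taken from this diff):

import torch

# What the caller asks the pipeline for.
batch_size = 1
# Classifier-free guidance is active whenever guidance_scale > 1.0.
do_classifier_free_guidance = True

latents = torch.randn(batch_size, 4, 64, 64)

# Diffusers pipelines concatenate the unconditional and conditional branches
# before every UNet call, so the UNet receives twice the requested batch.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

print(latent_model_input.shape[0])  # 2 -- this must match the compiled batch size

This is why the error message above suggests either compiling with a doubled batch size or adjusting the guidance scale.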
optimum/rbln/diffusers/pipelines/__init__.py

@@ -59,6 +59,9 @@ _import_structure = {
         "RBLNStableDiffusion3Img2ImgPipeline",
         "RBLNStableDiffusion3InpaintPipeline",
     ],
+    "stable_video_diffusion": [
+        "RBLNStableVideoDiffusionPipeline",
+    ],
 }
 if TYPE_CHECKING:
     from .auto_pipeline import (
@@ -98,6 +101,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
     )
+    from .stable_video_diffusion import RBLNStableVideoDiffusionPipeline
 else:
     import sys
 
optimum/rbln/diffusers/pipelines/auto_pipeline.py

@@ -15,7 +15,7 @@
 
 import importlib
 from pathlib import Path
-from typing import Any, Dict, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
 from diffusers.models.controlnets import ControlNetUnionModel
 from diffusers.pipelines.auto_pipeline import (
@@ -174,7 +174,7 @@ class RBLNAutoPipelineBase:
         model_id: Union[str, Path],
         *,
         export: bool = None,
-        rbln_config: Union[Dict[str, Any], RBLNModelConfig] =
+        rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None,
         **kwargs: Any,
     ):
         """
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py

@@ -96,6 +96,26 @@ class RBLNMultiControlNetModel(RBLNModel):
         guess_mode: bool = False,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized MultiControlNetModel.
+
+        This method processes multiple ControlNet models in sequence, applying each one to the input sample
+        with its corresponding conditioning image and scale factor. The outputs from all ControlNets are
+        merged by addition to produce the final control signals.
+
+        Args:
+            sample (torch.FloatTensor): The noisy input tensor.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states from the text encoder.
+            controlnet_cond (List[torch.Tensor]): A list of conditional input tensors, one for each ControlNet model.
+            conditioning_scale (List[float]): A list of scale factors for each ControlNet output. Each scale
+                controls the strength of the corresponding ControlNet's influence on the generation.
+            return_dict (bool): Whether or not to return a dictionary instead of a plain tuple. Currently,
+                this method always returns a tuple regardless of this parameter.
+
+        Returns:
+            (Tuple[List[torch.Tensor], torch.Tensor])
+        """
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
             down_samples, mid_sample = controlnet(
                 sample=sample.contiguous(),
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py

@@ -151,7 +151,9 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
             for image_ in image:
                 self.check_image(image_, prompt, prompt_embeds)
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # Check `controlnet_conditioning_scale`
         if (
@@ -180,7 +182,9 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
                     " the same length as the number of controlnets"
                 )
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]
@@ -254,7 +258,7 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] =
+        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
     ):
         r"""
@@ -393,6 +397,9 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
         )
 
         # 1. Check inputs. Raise error if not correct
+        if callback_on_step_end_tensor_inputs is None:
+            callback_on_step_end_tensor_inputs = ["latents"]
+
         self.check_inputs(
             prompt,
             image,
@@ -503,7 +510,9 @@ class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionC
             image = images
             height, width = image[0].shape[-2:]
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
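The `callback_on_step_end_tensor_inputs` change above, repeated in the three ControlNet pipelines that follow, replaces a list default in the signature with `Optional[...] = None` and resolves the actual default inside the body. A small standalone illustration of that idiom (the function and argument names are generic, not taken from this diff):

from typing import List, Optional


def collect_tensor_inputs(tensor_inputs: Optional[List[str]] = None) -> List[str]:
    # Resolving the default here avoids a mutable list literal in the signature,
    # which would be created once at definition time and shared across calls.
    if tensor_inputs is None:
        tensor_inputs = ["latents"]
    return tensor_inputs


print(collect_tensor_inputs())                               # ['latents']
print(collect_tensor_inputs(["latents", "prompt_embeds"]))   # caller-supplied list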
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

@@ -152,7 +152,9 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
             for image_ in image:
                 self.check_image(image_, prompt, prompt_embeds)
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # Check `controlnet_conditioning_scale`
         if (
@@ -178,7 +180,9 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
                     " the same length as the number of controlnets"
                 )
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         if len(control_guidance_start) != len(control_guidance_end):
             raise ValueError(
@@ -247,7 +251,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] =
+        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
     ):
         r"""
@@ -384,6 +388,9 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
         )
 
         # 1. Check inputs. Raise error if not correct
+        if callback_on_step_end_tensor_inputs is None:
+            callback_on_step_end_tensor_inputs = ["latents"]
+
         self.check_inputs(
             prompt,
             control_image,
@@ -490,7 +497,9 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDif
 
             control_image = control_images
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

@@ -178,7 +178,9 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
             for image_ in image:
                 self.check_image(image_, prompt, prompt_embeds)
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # Check `controlnet_conditioning_scale`
         if (
@@ -204,7 +206,9 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
                     " the same length as the number of controlnets"
                 )
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]
@@ -288,7 +292,7 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
         negative_target_size: Optional[Tuple[int, int]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] =
+        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
     ):
         r"""
@@ -466,6 +470,9 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
         )
 
         # 1. Check inputs. Raise error if not correct
+        if callback_on_step_end_tensor_inputs is None:
+            callback_on_step_end_tensor_inputs = ["latents"]
+
         self.check_inputs(
             prompt,
             prompt_2,
@@ -581,7 +588,9 @@ class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusio
             image = images
             height, width = image[0].shape[-2:]
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

@@ -190,7 +190,9 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
             for image_ in image:
                 self.check_image(image_, prompt, prompt_embeds)
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # Check `controlnet_conditioning_scale`
         if (
@@ -216,7 +218,9 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
                     " the same length as the number of controlnets"
                 )
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]
@@ -303,7 +307,7 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
-        callback_on_step_end_tensor_inputs: List[str] =
+        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
         **kwargs,
     ):
         r"""
@@ -500,6 +504,9 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
         )
 
         # 1. Check inputs. Raise error if not correct
+        if callback_on_step_end_tensor_inputs is None:
+            callback_on_step_end_tensor_inputs = ["latents"]
+
         self.check_inputs(
             prompt,
             prompt_2,
@@ -618,7 +625,9 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableD
             control_image = control_images
             height, width = control_image[0].shape[-2:]
         else:
-
+            raise TypeError(
+                "Unsupported controlnet type. Expected `RBLNControlNetModel` or `RBLNMultiControlNetModel`."
+            )
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py

@@ -86,7 +86,7 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
         *,
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
-        rbln_config: Dict[str, Any] =
+        rbln_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ):
         """
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

@@ -86,7 +86,7 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
         *,
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
-        rbln_config: Dict[str, Any] =
+        rbln_config: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ):
         """
@@ -118,7 +118,6 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
                 RBLN compilation process. These may include parameters specific to individual submodules
                 or the particular diffusion pipeline being used.
         """
-
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py (new file)

@@ -0,0 +1,15 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline_stable_video_diffusion import RBLNStableVideoDiffusionPipeline
optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py (new file)

@@ -0,0 +1,46 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from diffusers import StableVideoDiffusionPipeline
+
+from ....utils.logging import get_logger
+from ...configurations import RBLNStableVideoDiffusionPipelineConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+logger = get_logger(__name__)
+
+
+class RBLNStableVideoDiffusionPipeline(RBLNDiffusionMixin, StableVideoDiffusionPipeline):
+    """
+    RBLN-accelerated implementation of Stable Video Diffusion pipeline for image-to-video generation.
+
+    This pipeline compiles Stable Video Diffusion models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for generating videos from images with optimized memory usage and throughput.
+    """
+
+    original_class = StableVideoDiffusionPipeline
+    _rbln_config_class = RBLNStableVideoDiffusionPipelineConfig
+    _submodules = ["image_encoder", "unet", "vae"]
+
+    def handle_additional_kwargs(self, **kwargs):
+        compiled_num_frames = self.unet.rbln_config.num_frames
+        if compiled_num_frames is not None:
+            kwargs["num_frames"] = compiled_num_frames
+
+        compiled_decode_chunk_size = self.vae.rbln_config.decode_chunk_size
+        if compiled_decode_chunk_size is not None:
+            kwargs["decode_chunk_size"] = compiled_decode_chunk_size
+        return kwargs
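Together with the new UNet, temporal-decoder VAE, and configuration classes listed at the top of this diff, this file makes Stable Video Diffusion available through the usual optimum-rbln entry points. A hedged usage sketch, assuming the class is re-exported at the package root like the other pipelines and following the standard compile-then-load flow; the checkpoint id and output handling follow the upstream diffusers pipeline and are not taken from this diff:

from PIL import Image

from optimum.rbln import RBLNStableVideoDiffusionPipeline

# Compile the pipeline submodules (image_encoder, unet, vae) for the RBLN NPU
# and save the compiled artifacts. The checkpoint id is illustrative.
pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    export=True,
)
pipe.save_pretrained("svd-rbln")

# Later, load the precompiled pipeline and generate frames from a single image.
pipe = RBLNStableVideoDiffusionPipeline.from_pretrained("svd-rbln", export=False)
image = Image.open("input_frame.png")
frames = pipe(image).frames[0]

Note that `handle_additional_kwargs` above overrides `num_frames` and `decode_chunk_size` with the values fixed at compile time, so passing different values at call time will have no effect.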