optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff compares publicly released versions of the package as published to one of the supported registries; it is provided for informational purposes only.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
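
The headline change in 0.9.4a2 is Stable Video Diffusion support: a temporal-decoder VAE, a spatio-temporal UNet, and a new pipeline package (items 7, 8, 10, 15, 23, 34, 35 above). A minimal loading sketch follows; the class name RBLNStableVideoDiffusionPipeline and the rbln_config keys are assumptions inferred from the file list and the usual optimum-rbln from_pretrained(..., export=True) convention, not documented API.

# Hedged sketch, not verified against 0.9.4a2: assumes the pipeline is exposed
# as RBLNStableVideoDiffusionPipeline and accepts the usual optimum-rbln
# `export` / `rbln_config` keyword arguments.
from optimum.rbln import RBLNStableVideoDiffusionPipeline  # assumed export name

pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",   # any SVD checkpoint
    export=True,                                    # compile for the RBLN NPU on first load
    rbln_config={"vae": {"decode_chunk_size": 7}},  # see chunk_frame() in the VAE diff below
)
pipe.save_pretrained("svd-rbln")                    # reuse the compiled artifacts later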
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py (new file)
@@ -0,0 +1,275 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+ import rebel
+ import torch  # noqa: I001
+ from diffusers import AutoencoderKLTemporalDecoder
+ from diffusers.models.autoencoders.vae import DecoderOutput
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
+ from transformers import PretrainedConfig
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from .vae import (
+     DiagonalGaussianDistribution,
+     RBLNRuntimeVAEDecoder,
+     RBLNRuntimeVAEEncoder,
+     _VAEEncoder,
+     _VAETemporalDecoder,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+     from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+ logger = get_logger(__name__)
+
+
+ class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
+     auto_model_class = AutoencoderKLTemporalDecoder
+     hf_library_name = "diffusers"
+     _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+
+         if self.rbln_config.uses_encoder:
+             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
+         self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+         self.image_size = self.rbln_config.image_size
+
+     @classmethod
+     def _wrap_model_if_needed(
+         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+     ) -> torch.nn.Module:
+         decoder_model = _VAETemporalDecoder(model)
+         decoder_model.num_frames = rbln_config.decode_chunk_size
+         decoder_model.eval()
+
+         if rbln_config.uses_encoder:
+             encoder_model = _VAEEncoder(model)
+             encoder_model.eval()
+             return encoder_model, decoder_model
+         else:
+             return decoder_model
+
+     @classmethod
+     def get_compiled_model(
+         cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+     ) -> Dict[str, rebel.RBLNCompiledModel]:
+         compiled_models = {}
+         if rbln_config.uses_encoder:
+             encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+             enc_compiled_model = cls.compile(
+                 encoder_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[0],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map["encoder"],
+             )
+             compiled_models["encoder"] = enc_compiled_model
+         else:
+             decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+         dec_compiled_model = cls.compile(
+             decoder_model,
+             rbln_compile_config=rbln_config.compile_cfgs[-1],
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device_map["decoder"],
+         )
+         compiled_models["decoder"] = dec_compiled_model
+
+         return compiled_models
+
+     @classmethod
+     def get_vae_sample_size(
+         cls,
+         pipe: "RBLNDiffusionMixin",
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+         return_vae_scale_factor: bool = False,
+     ) -> Tuple[int, int]:
+         sample_size = rbln_config.sample_size
+         if hasattr(pipe, "vae_scale_factor"):
+             vae_scale_factor = pipe.vae_scale_factor
+         else:
+             if hasattr(pipe.vae.config, "block_out_channels"):
+                 vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+             else:
+                 vae_scale_factor = 8  # vae image processor default value 8 (int)
+
+         if sample_size is None:
+             sample_size = pipe.unet.config.sample_size
+             if isinstance(sample_size, int):
+                 sample_size = (sample_size, sample_size)
+             sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
+
+         if return_vae_scale_factor:
+             return sample_size, vae_scale_factor
+         else:
+             return sample_size
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+             pipe, rbln_config.vae, return_vae_scale_factor=True
+         )
+
+         if rbln_config.vae.num_frames is None:
+             if hasattr(pipe.unet.config, "num_frames"):
+                 rbln_config.vae.num_frames = pipe.unet.config.num_frames
+             else:
+                 raise ValueError("num_frames should be specified in unet config.json")
+
+         if rbln_config.vae.decode_chunk_size is None:
+             rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames
+
+         def chunk_frame(num_frames, decode_chunk_size):
+             # get closest divisor to num_frames
+             divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
+             closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
+             if decode_chunk_size != closest:
+                 logger.warning(
+                     f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
+                 )
+             return closest
+
+         decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
+         rbln_config.vae.decode_chunk_size = decode_chunk_size
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+     ) -> RBLNAutoencoderKLTemporalDecoderConfig:
+         if rbln_config.sample_size is None:
+             rbln_config.sample_size = model_config.sample_size
+
+         if rbln_config.vae_scale_factor is None:
+             if hasattr(model_config, "block_out_channels"):
+                 rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+             else:
+                 # vae image processor default value 8 (int)
+                 rbln_config.vae_scale_factor = 8
+
+         compile_cfgs = []
+         if rbln_config.uses_encoder:
+             vae_enc_input_info = [
+                 (
+                     "x",
+                     [
+                         rbln_config.batch_size,
+                         model_config.in_channels,
+                         rbln_config.sample_size[0],
+                         rbln_config.sample_size[1],
+                     ],
+                     "float32",
+                 )
+             ]
+             compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+         decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
+         vae_dec_input_info = [
+             (
+                 "z",
+                 [
+                     decode_batch_size,
+                     model_config.latent_channels,
+                     rbln_config.latent_sample_size[0],
+                     rbln_config.latent_sample_size[1],
+                 ],
+                 "float32",
+             )
+         ]
+         compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+         rbln_config.set_compile_cfgs(compile_cfgs)
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+     ) -> List[rebel.Runtime]:
+         if len(compiled_models) == 1:
+             # decoder
+             expected_models = ["decoder"]
+         else:
+             expected_models = ["encoder", "decoder"]
+
+         if any(model_name not in rbln_config.device_map for model_name in expected_models):
+             cls._raise_missing_compiled_file_error(expected_models)
+
+         device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+         return [
+             rebel.Runtime(
+                 compiled_model,
+                 tensor_type="pt",
+                 device=device_val,
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             )
+             for compiled_model, device_val in zip(compiled_models, device_vals)
+         ]
+
+     def encode(
+         self, x: torch.FloatTensor, return_dict: bool = True
+     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+         """
+         Encode an input image into a latent representation.
+
+         Args:
+             x: The input image to encode.
+             return_dict:
+                 Whether to return output as a dictionary. Defaults to True.
+
+         Returns:
+             The latent representation or AutoencoderKLOutput if return_dict=True
+         """
+         posterior = self.encoder.encode(x)
+
+         if not return_dict:
+             return (posterior,)
+
+         return AutoencoderKLOutput(latent_dist=posterior)
+
+     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+         """
+         Decode a latent representation into a video.
+
+         Args:
+             z: The latent representation to decode.
+             return_dict:
+                 Whether to return output as a dictionary. Defaults to True.
+
+         Returns:
+             The decoded video or DecoderOutput if return_dict=True
+         """
+         decoded = self.decoder.decode(z)
+
+         if not return_dict:
+             return (decoded,)
+
+         return DecoderOutput(sample=decoded)
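
A note on the decode_chunk_size handling above: chunk_frame snaps the requested chunk size to the nearest proper divisor of num_frames (num_frames itself is excluded by range(1, num_frames)), so the compiled decoder batch always divides the frame count evenly. A standalone sketch of that arithmetic, using illustrative frame counts rather than values from any particular checkpoint:

# Standalone re-implementation of the divisor snapping in
# update_rbln_config_using_pipe above; the frame counts here are illustrative.
def chunk_frame(num_frames: int, decode_chunk_size: int) -> int:
    # proper divisors only: range(1, num_frames) excludes num_frames itself
    divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
    return min(divisors, key=lambda x: abs(x - decode_chunk_size))

print(chunk_frame(14, 14))  # 7 -> the default decode_chunk_size=num_frames gets reduced
print(chunk_frame(25, 8))   # 5 -> divisors of 25 are [1, 5]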
optimum/rbln/diffusers/models/autoencoders/vae.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, List
+ from typing import TYPE_CHECKING, List, Union

  import torch
  from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime


  if TYPE_CHECKING:
-     from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
+     from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel


  class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
@@ -67,18 +67,37 @@ class _VAEDecoder(torch.nn.Module):
          return vae_out


+ class _VAETemporalDecoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
+         super().__init__()
+         self.vae = vae
+         self.num_frames = None
+
+     def forward(self, z):
+         vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
+         return vae_out
+
+
  class _VAEEncoder(torch.nn.Module):
-     def __init__(self, vae: "AutoencoderKL"):
+     def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
          super().__init__()
          self.vae = vae

      def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-             return self.tiled_encode(x, return_dict=return_dict)
+         if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
+             if self.use_tiling and (
+                 x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
+             ):
+                 return self.tiled_encode(x, return_dict=return_dict)
+
+             if self.use_slicing and x.shape[0] > 1:
+                 encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+                 h = torch.cat(encoded_slices)
+             else:
+                 h = self.encoder(x)
+                 if self.quant_conv is not None:
+                     h = self.quant_conv(h)

-         if self.use_slicing and x.shape[0] > 1:
-             encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-             h = torch.cat(encoded_slices)
          else:
              h = self.encoder(x)
              if self.quant_conv is not None:
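
For context on _VAETemporalDecoder above: the wrapper pins num_frames at compile time and calls decode with return_dict=False so tracing sees plain tensors. The sketch below exercises it eagerly in PyTorch (no RBLN compilation); the checkpoint and latent shape are illustrative assumptions.

# Eager-mode check of the wrapper added in vae.py above; shapes and checkpoint are illustrative.
import torch
from diffusers import AutoencoderKLTemporalDecoder

from optimum.rbln.diffusers.models.autoencoders.vae import _VAETemporalDecoder  # added in 0.9.4a2

vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="vae"
)
wrapper = _VAETemporalDecoder(vae)
wrapper.num_frames = 7  # fixed per compiled graph; matches decode_chunk_size
z = torch.randn(7, vae.config.latent_channels, 72, 128)  # (batch * frames, C, H/8, W/8)
with torch.no_grad():
    (frames,) = wrapper(z)  # tuple because the wrapper passes return_dict=False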
optimum/rbln/diffusers/models/controlnet.py
@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
          )

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          use_encoder_hidden_states = False
          for down_block in model.down_blocks:
              if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -215,10 +215,25 @@ class RBLNControlNetModel(RBLNModel):
          encoder_hidden_states: torch.Tensor,
          controlnet_cond: torch.FloatTensor,
          conditioning_scale: torch.Tensor = 1.0,
-         added_cond_kwargs: Dict[str, torch.Tensor] = {},
+         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
          return_dict: bool = True,
          **kwargs,
      ):
+         """
+         Forward pass for the RBLN-optimized ControlNetModel.
+
+         Args:
+             sample (torch.FloatTensor): The noisy input tensor.
+             timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+             encoder_hidden_states (torch.Tensor): The encoder hidden states.
+             controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
+             conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
+             added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
+
+         Returns:
+             (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
+         """
          sample_batch_size = sample.size()[0]
          compiled_batch_size = self.compiled_batch_size
          if sample_batch_size != compiled_batch_size and (
optimum/rbln/diffusers/models/transformers/prior_transformer.py
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
          self.clip_std = artifacts["clip_std"]

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          return _PriorTransformer(model).eval()

      @classmethod
@@ -128,13 +128,27 @@ class RBLNPriorTransformer(RBLNModel):

      def forward(
          self,
-         hidden_states,
+         hidden_states: torch.Tensor,
          timestep: Union[torch.Tensor, float, int],
          proj_embedding: torch.Tensor,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          return_dict: bool = True,
      ):
+         """
+         Forward pass for the RBLN-optimized PriorTransformer.
+
+         Args:
+             hidden_states (torch.Tensor): The currently predicted image embeddings.
+             timestep (Union[torch.Tensor, float, int]): Current denoising step.
+             proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
+             encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
+             attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
+         """
          # Convert timestep(long) and attention_mask(bool) to float
          return super().forward(
              hidden_states,
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
@@ -185,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
          )

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          num_latent_frames = rbln_config.num_latent_frames
          latent_height = rbln_config.latent_height
          latent_width = rbln_config.latent_width
@@ -303,6 +303,21 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
          padding_mask: Optional[torch.Tensor] = None,
          return_dict: bool = True,
      ):
+         """
+         Forward pass for the RBLN-optimized CosmosTransformer3DModel.
+
+         Args:
+             hidden_states (torch.Tensor): The currently predicted image embeddings.
+             timestep (torch.Tensor): Current denoising step.
+             encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             fps (Optional[int]): Frames per second for the video being generated.
+             condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
+             padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+         """
          (
              hidden_states,
              temb,
optimum/rbln/diffusers/models/transformers/transformer_sd3.py
@@ -77,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
          super().__post_init__(**kwargs)

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          return SD3Transformer2DModelWrapper(model).eval()

      @classmethod
@@ -161,6 +161,19 @@ class RBLNSD3Transformer2DModel(RBLNModel):
          return_dict: bool = True,
          **kwargs,
      ):
+         """
+         Forward pass for the RBLN-optimized SD3Transformer2DModel.
+
+         Args:
+             hidden_states (torch.FloatTensor): The currently predicted image embeddings.
+             encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
+             timestep (torch.LongTensor): Current denoising step.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+         """
          sample_batch_size = hidden_states.size()[0]
          compiled_batch_size = self.compiled_batch_size
          if sample_batch_size != compiled_batch_size and (
optimum/rbln/diffusers/models/unets/__init__.py
@@ -13,3 +13,4 @@
  # limitations under the License.

  from .unet_2d_condition import RBLNUNet2DConditionModel
+ from .unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModel
optimum/rbln/diffusers/models/unets/unet_2d_condition.py
@@ -171,7 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
              self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

      @classmethod
-     def wrap_model_if_needed(
+     def _wrap_model_if_needed(
          cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
      ) -> torch.nn.Module:
          if model.config.addition_embed_type == "text_time":
@@ -341,7 +341,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
          timestep_cond: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-         added_cond_kwargs: Dict[str, torch.Tensor] = {},
+         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
          down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
          mid_block_additional_residual: Optional[torch.Tensor] = None,
          down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
@@ -349,6 +349,22 @@ class RBLNUNet2DConditionModel(RBLNModel):
          return_dict: bool = True,
          **kwargs,
      ) -> Union[UNet2DConditionOutput, Tuple]:
+         """
+         Forward pass for the RBLN-optimized UNet2DConditionModel.
+
+         Args:
+             sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
+             timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+             encoder_hidden_states (torch.Tensor): The encoder hidden states.
+             added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
+                 if specified are added to the embeddings that are passed along to the UNet blocks.
+             down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
+             mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
+         """
          sample_batch_size = sample.size()[0]
          compiled_batch_size = self.compiled_batch_size
          if sample_batch_size != compiled_batch_size and (