optimum-rbln 0.9.3rc0-py3-none-any.whl → 0.9.4a2-py3-none-any.whl

This diff shows the changes between two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages' contents as published.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
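
The largest addition in this release is Stable Video Diffusion support in the RBLN diffusers stack. As a minimal usage sketch (not taken from this diff): the new pipeline is assumed to follow the same `from_pretrained`/`export` pattern as the existing RBLN diffusers pipelines, with `rbln_config` keys matching `RBLNStableVideoDiffusionPipelineConfig` below; the model id and values are illustrative.

# Hypothetical usage sketch for the new Stable Video Diffusion support.
# Assumes RBLNStableVideoDiffusionPipeline follows the from_pretrained/export
# pattern of the existing RBLN diffusers pipelines (see the rbln_config
# parameter in modeling_diffusers.py below); model id and values are
# illustrative, not taken from this diff.
from optimum.rbln import RBLNStableVideoDiffusionPipeline

pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    export=True,  # compile for the RBLN NPU on first load
    rbln_config={
        "height": 576,
        "width": 1024,
        "num_frames": 25,
        "decode_chunk_size": 8,  # bound VAE memory by decoding 8 frames per pass
    },
)
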
optimum/rbln/__init__.py CHANGED
@@ -186,12 +186,16 @@ _import_structure = {
     "diffusers": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLConfig",
+        "RBLNAutoencoderKLTemporalDecoder",
+        "RBLNAutoencoderKLTemporalDecoderConfig",
         "RBLNAutoencoderKLCosmos",
         "RBLNAutoencoderKLCosmosConfig",
         "RBLNAutoPipelineForImage2Image",
         "RBLNAutoPipelineForInpainting",
         "RBLNAutoPipelineForText2Image",
         "RBLNControlNetModel",
+        "RBLNUNetSpatioTemporalConditionModel",
+        "RBLNStableVideoDiffusionPipeline",
         "RBLNControlNetModelConfig",
         "RBLNCosmosTextToWorldPipeline",
         "RBLNCosmosVideoToWorldPipeline",
@@ -250,6 +254,8 @@ _import_structure = {
         "RBLNUNet2DConditionModelConfig",
         "RBLNVQModel",
         "RBLNVQModelConfig",
+        "RBLNUNetSpatioTemporalConditionModelConfig",
+        "RBLNStableVideoDiffusionPipelineConfig",
     ],
 }
 
@@ -260,6 +266,8 @@ if TYPE_CHECKING:
         RBLNAutoencoderKLConfig,
         RBLNAutoencoderKLCosmos,
         RBLNAutoencoderKLCosmosConfig,
+        RBLNAutoencoderKLTemporalDecoder,
+        RBLNAutoencoderKLTemporalDecoderConfig,
         RBLNAutoPipelineForImage2Image,
         RBLNAutoPipelineForInpainting,
         RBLNAutoPipelineForText2Image,
@@ -318,8 +326,12 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLInpaintPipelineConfig,
         RBLNStableDiffusionXLPipeline,
         RBLNStableDiffusionXLPipelineConfig,
+        RBLNStableVideoDiffusionPipeline,
+        RBLNStableVideoDiffusionPipelineConfig,
         RBLNUNet2DConditionModel,
         RBLNUNet2DConditionModelConfig,
+        RBLNUNetSpatioTemporalConditionModel,
+        RBLNUNetSpatioTemporalConditionModelConfig,
         RBLNVQModel,
         RBLNVQModelConfig,
     )
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.9.3rc0'
-__version_tuple__ = version_tuple = (0, 9, 3, 'rc0')
+__version__ = version = '0.9.4a2'
+__version_tuple__ = version_tuple = (0, 9, 4, 'a2')
 
 __commit_id__ = commit_id = None
optimum/rbln/configuration_utils.py CHANGED
@@ -24,7 +24,7 @@ import torch
 from packaging.version import Version
 
 from .__version__ import __version__
-from .utils.depreacate_utils import warn_deprecated_npu
+from .utils.deprecation import warn_deprecated_npu
 from .utils.logging import get_logger
 from .utils.runtime_utils import ContextRblnConfig
 
@@ -117,9 +117,14 @@ class RBLNCompileConfig:
         return self
 
     def get_dummy_inputs(
-        self, fill=0, static_tensors: Dict[str, torch.Tensor] = {}, meta_tensor_names: List[str] = []
+        self,
+        fill=0,
+        static_tensors: Optional[Dict[str, torch.Tensor]] = None,
+        meta_tensor_names: Optional[List[str]] = None,
     ):
         dummy = []
+        static_tensors = static_tensors if static_tensors is not None else {}
+        meta_tensor_names = meta_tensor_names if meta_tensor_names is not None else []
         for name, shape, dtype in self.input_info:
             if name in static_tensors:
                 tensor = static_tensors[name]
@@ -255,7 +260,7 @@ class RBLNAutoConfig:
     def load(
         path: str,
         passed_rbln_config: Optional["RBLNModelConfig"] = None,
-        kwargs: Optional[Dict[str, Any]] = {},
+        kwargs: Optional[Dict[str, Any]] = None,
         return_unused_kwargs: bool = False,
     ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
         """
@@ -269,6 +274,8 @@ class RBLNAutoConfig:
         Returns:
             RBLNModelConfig: The loaded RBLNModelConfig.
         """
+        if kwargs is None:
+            kwargs = {}
         cls, config_file = load_config(path)
 
         rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
@@ -528,6 +535,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     ]
     submodules: List[str] = []
    subclass_non_save_attributes = []
+    _allow_no_compile_cfgs = False
 
     def initialize_submodule_config(
         self,
@@ -654,7 +662,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         timeout: Optional[int] = None,
         optimum_rbln_version: Optional[str] = None,
         _torch_dtype: Optional[str] = None,
-        _compile_cfgs: List[RBLNCompileConfig] = [],
+        _compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
         *,
         optimize_host_memory: Optional[bool] = None,
         **kwargs: Any,
@@ -707,7 +715,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         if self.optimum_rbln_version is None:
             self.optimum_rbln_version = __version__
 
-        self._compile_cfgs: List[RBLNCompileConfig] = _compile_cfgs
+        compile_cfgs = _compile_cfgs if _compile_cfgs is not None else []
+        self._compile_cfgs: List[RBLNCompileConfig] = compile_cfgs
 
         if not isinstance(self._compile_cfgs, list):
             raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
@@ -808,7 +817,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             or len(self._compile_cfgs) == 0
             or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
         ):
-            raise RuntimeError("`compile_cfgs` must be set before freezing.")
+            if not self._allow_no_compile_cfgs:
+                raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")
 
         for submodule_name in self.submodules:
             submodule_config = getattr(self, submodule_name, None)
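
Several hunks above replace mutable default arguments ({} and []) with None plus an explicit in-body fallback. A minimal standalone sketch of the pitfall being fixed (names are illustrative, not from the package):

# A default dict is created once, at definition time, and shared across calls.
def register_bad(name, registry={}):
    registry[name] = True
    return registry

register_bad("a")   # {'a': True}
register_bad("b")   # {'a': True, 'b': True} -- state leaks across calls

# The pattern adopted above: default to None, build a fresh container inside.
def register_good(name, registry=None):
    registry = registry if registry is not None else {}
    registry[name] = True
    return registry

register_good("a")  # {'a': True}
register_good("b")  # {'b': True}
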
optimum/rbln/diffusers/__init__.py CHANGED
@@ -57,6 +57,9 @@ _import_structure = {
         "RBLNSD3Transformer2DModelConfig",
         "RBLNUNet2DConditionModelConfig",
         "RBLNVQModelConfig",
+        "RBLNUNetSpatioTemporalConditionModelConfig",
+        "RBLNStableVideoDiffusionPipelineConfig",
+        "RBLNAutoencoderKLTemporalDecoderConfig",
     ],
     "pipelines": [
         "RBLNAutoPipelineForImage2Image",
@@ -86,14 +89,17 @@ _import_structure = {
         "RBLNStableDiffusion3Pipeline",
         "RBLNStableDiffusion3Img2ImgPipeline",
         "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableVideoDiffusionPipeline",
     ],
     "models": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLCosmos",
         "RBLNUNet2DConditionModel",
+        "RBLNUNetSpatioTemporalConditionModel",
         "RBLNControlNetModel",
         "RBLNCosmosTransformer3DModel",
         "RBLNSD3Transformer2DModel",
+        "RBLNAutoencoderKLTemporalDecoder",
         "RBLNPriorTransformer",
         "RBLNVQModel",
     ],
@@ -106,6 +112,7 @@ if TYPE_CHECKING:
     from .configurations import (
         RBLNAutoencoderKLConfig,
         RBLNAutoencoderKLCosmosConfig,
+        RBLNAutoencoderKLTemporalDecoderConfig,
         RBLNControlNetModelConfig,
         RBLNCosmosTextToWorldPipelineConfig,
         RBLNCosmosTransformer3DModelConfig,
@@ -132,18 +139,22 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLImg2ImgPipelineConfig,
         RBLNStableDiffusionXLInpaintPipelineConfig,
         RBLNStableDiffusionXLPipelineConfig,
+        RBLNStableVideoDiffusionPipelineConfig,
         RBLNUNet2DConditionModelConfig,
+        RBLNUNetSpatioTemporalConditionModelConfig,
         RBLNVQModelConfig,
     )
     from .modeling_diffusers import RBLNDiffusionMixin
     from .models import (
         RBLNAutoencoderKL,
         RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLTemporalDecoder,
         RBLNControlNetModel,
         RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNUNet2DConditionModel,
+        RBLNUNetSpatioTemporalConditionModel,
         RBLNVQModel,
     )
     from .pipelines import (
@@ -174,6 +185,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLImg2ImgPipeline,
         RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
+        RBLNStableVideoDiffusionPipeline,
     )
 else:
     import sys
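
The `_import_structure` / `TYPE_CHECKING` split seen in these `__init__` files implements lazy imports: exported symbols are listed up front, and the defining submodule is only imported on first attribute access. A simplified sketch of the pattern (the real implementation is `transformers.utils._LazyModule`; this is not the exact code):

import importlib
import sys
from types import ModuleType


class _LazyModule(ModuleType):
    """Simplified stand-in for the lazy module used by the real __init__ files."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported symbol to the submodule that defines it.
        self._symbol_to_module = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        # Import the defining submodule only when the symbol is first used.
        submodule = importlib.import_module(f".{self._symbol_to_module[symbol]}", self.__name__)
        return getattr(submodule, symbol)

# The else-branch of each __init__ then does, in effect:
#     sys.modules[__name__] = _LazyModule(__name__, _import_structure)
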
optimum/rbln/diffusers/configurations/__init__.py CHANGED
@@ -1,11 +1,13 @@
 from .models import (
     RBLNAutoencoderKLConfig,
     RBLNAutoencoderKLCosmosConfig,
+    RBLNAutoencoderKLTemporalDecoderConfig,
     RBLNControlNetModelConfig,
     RBLNCosmosTransformer3DModelConfig,
     RBLNPriorTransformerConfig,
     RBLNSD3Transformer2DModelConfig,
     RBLNUNet2DConditionModelConfig,
+    RBLNUNetSpatioTemporalConditionModelConfig,
     RBLNVQModelConfig,
 )
 from .pipelines import (
@@ -31,4 +33,5 @@ from .pipelines import (
     RBLNStableDiffusionXLImg2ImgPipelineConfig,
     RBLNStableDiffusionXLInpaintPipelineConfig,
     RBLNStableDiffusionXLPipelineConfig,
+    RBLNStableVideoDiffusionPipelineConfig,
 )
optimum/rbln/diffusers/configurations/models/__init__.py CHANGED
@@ -1,8 +1,10 @@
 from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
 from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
+from .configuration_autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoderConfig
 from .configuration_controlnet import RBLNControlNetModelConfig
 from .configuration_prior_transformer import RBLNPriorTransformerConfig
 from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
 from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
 from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
+from .configuration_unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModelConfig
 from .configuration_vq_model import RBLNVQModelConfig
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py ADDED
@@ -0,0 +1,67 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNAutoencoderKLTemporalDecoderConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        uses_encoder: Optional[bool] = None,
+        num_frames: Optional[int] = None,
+        decode_chunk_size: Optional[int] = None,
+        vae_scale_factor: Optional[float] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the input/output images.
+                If an integer is provided, it's used for both height and width.
+            uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
+                When False, only the decoder is used (for latent-to-image conversion).
+            num_frames (Optional[int]): The number of frames in the generated video.
+            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
+                Useful for managing memory usage during video generation.
+            vae_scale_factor (Optional[float]): The scaling factor between pixel space and latent space.
+                Determines how much smaller the latent representations are compared to the original images.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.uses_encoder = uses_encoder
+        self.num_frames = num_frames
+        self.decode_chunk_size = decode_chunk_size
+        self.vae_scale_factor = vae_scale_factor
+        self.sample_size = sample_size
+        if isinstance(sample_size, int):
+            self.sample_size = (sample_size, sample_size)
+
+    @property
+    def image_size(self):
+        return self.sample_size
+
+    @property
+    def latent_sample_size(self):
+        return (self.image_size[0] // self.vae_scale_factor, self.image_size[1] // self.vae_scale_factor)
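
An illustrative construction of the config added above; the values are made up, and this assumes the class can be instantiated standalone (it is exported from optimum.rbln by this release):

from optimum.rbln import RBLNAutoencoderKLTemporalDecoderConfig

cfg = RBLNAutoencoderKLTemporalDecoderConfig(
    batch_size=1,
    sample_size=576,      # an int is normalized to (576, 576)
    uses_encoder=True,
    num_frames=25,
    decode_chunk_size=8,
    vae_scale_factor=8,
)
cfg.sample_size           # (576, 576)
cfg.latent_sample_size    # (72, 72): each axis floor-divided by vae_scale_factor
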
optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,59 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNUNetSpatioTemporalConditionModelConfig(RBLNModelConfig):
+    subclass_non_save_attributes = ["_batch_size_is_specified"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        in_features: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the generated samples.
+                If an integer is provided, it's used for both height and width.
+            in_features (Optional[int]): Number of input features for the model.
+            num_frames (Optional[int]): The number of frames in the generated video.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self._batch_size_is_specified = batch_size is not None
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.in_features = in_features
+        self.num_frames = num_frames
+
+        self.sample_size = sample_size
+        if isinstance(sample_size, int):
+            self.sample_size = (sample_size, sample_size)
+
+    @property
+    def batch_size_is_specified(self):
+        return self._batch_size_is_specified
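
The `_batch_size_is_specified` flag (excluded from serialization via `subclass_non_save_attributes`) records whether the caller set an explicit batch size; the pipeline config further down uses it to decide whether it may resize the UNet batch for classifier-free guidance. Illustrative, assuming standalone construction works:

from optimum.rbln import RBLNUNetSpatioTemporalConditionModelConfig

unet_cfg = RBLNUNetSpatioTemporalConditionModelConfig(batch_size=2, num_frames=25)
unet_cfg.batch_size_is_specified     # True: the pipeline config will not override it

default_cfg = RBLNUNetSpatioTemporalConditionModelConfig()
default_cfg.batch_size               # 1 (the default)
default_cfg.batch_size_is_specified  # False: eligible for CFG batch doubling
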
optimum/rbln/diffusers/configurations/pipelines/__init__.py CHANGED
@@ -29,3 +29,6 @@ from .configuration_stable_diffusion_xl import (
     RBLNStableDiffusionXLInpaintPipelineConfig,
     RBLNStableDiffusionXLPipelineConfig,
 )
+from .configuration_stable_video_diffusion import (
+    RBLNStableVideoDiffusionPipelineConfig,
+)
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py ADDED
@@ -0,0 +1,114 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNCLIPVisionModelWithProjectionConfig
+from ..models import RBLNAutoencoderKLTemporalDecoderConfig, RBLNUNetSpatioTemporalConditionModelConfig
+
+
+class RBLNStableVideoDiffusionPipelineConfig(RBLNModelConfig):
+    submodules = ["image_encoder", "unet", "vae"]
+    _vae_uses_encoder = True
+
+    def __init__(
+        self,
+        image_encoder: Optional[RBLNCLIPVisionModelWithProjectionConfig] = None,
+        unet: Optional[RBLNUNetSpatioTemporalConditionModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLTemporalDecoderConfig] = None,
+        *,
+        batch_size: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        decode_chunk_size: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            image_encoder (Optional[RBLNCLIPVisionModelWithProjectionConfig]): Configuration for the image encoder component.
+                Initialized as RBLNCLIPVisionModelWithProjectionConfig if not provided.
+            unet (Optional[RBLNUNetSpatioTemporalConditionModelConfig]): Configuration for the UNet model component.
+                Initialized as RBLNUNetSpatioTemporalConditionModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLTemporalDecoderConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLTemporalDecoderConfig if not provided.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            height (Optional[int]): Height of the generated images.
+            width (Optional[int]): Width of the generated images.
+            num_frames (Optional[int]): The number of frames in the generated video.
+            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
+                Useful for managing memory usage during video generation.
+            guidance_scale (Optional[float]): Scale for classifier-free guidance.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If both image_size and height/width are provided.
+
+        Note:
+            When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+            accommodate classifier-free guidance.
+        """
+        super().__init__(**kwargs)
+        if height is not None and width is not None:
+            image_size = (height, width)
+        else:
+            # Get default image size from original class to set UNet, VAE image size
+            height = self.get_default_values_for_original_cls("__call__", ["height"])["height"]
+            width = self.get_default_values_for_original_cls("__call__", ["width"])["width"]
+            image_size = (height, width)
+
+        self.image_encoder = self.initialize_submodule_config(
+            image_encoder, cls_name="RBLNCLIPVisionModelWithProjectionConfig", batch_size=batch_size
+        )
+        self.unet = self.initialize_submodule_config(
+            unet,
+            cls_name="RBLNUNetSpatioTemporalConditionModelConfig",
+            num_frames=num_frames,
+        )
+        self.vae = self.initialize_submodule_config(
+            vae,
+            cls_name="RBLNAutoencoderKLTemporalDecoderConfig",
+            batch_size=batch_size,
+            num_frames=num_frames,
+            decode_chunk_size=decode_chunk_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            sample_size=image_size,  # image size is equal to sample size in vae
+        )
+
+        # Get default guidance scale from original class to set UNet batch size
+        if guidance_scale is None:
+            guidance_scale = self.get_default_values_for_original_cls("__call__", ["max_guidance_scale"])[
+                "max_guidance_scale"
+            ]
+
+        if not self.unet.batch_size_is_specified:
+            do_classifier_free_guidance = guidance_scale > 1.0
+            if do_classifier_free_guidance:
+                self.unet.batch_size = self.image_encoder.batch_size * 2
+            else:
+                self.unet.batch_size = self.image_encoder.batch_size
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def sample_size(self):
+        return self.unet.sample_size
+
+    @property
+    def image_size(self):
+        return self.vae.sample_size
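
A sketch of the classifier-free-guidance batch sizing encoded above. Constructor values are illustrative; height/width are passed explicitly to skip the default lookup on the original pipeline class, and the exact behavior of initialize_submodule_config is assumed from the code shown:

from optimum.rbln import RBLNStableVideoDiffusionPipelineConfig

cfg = RBLNStableVideoDiffusionPipelineConfig(
    batch_size=1, height=576, width=1024, guidance_scale=3.0
)
# guidance_scale > 1.0 and no explicit UNet batch size, so the UNet is compiled
# with twice the image-encoder batch to hold both CFG branches in one static graph.
assert cfg.unet.batch_size == cfg.image_encoder.batch_size * 2

no_cfg = RBLNStableVideoDiffusionPipelineConfig(
    batch_size=1, height=576, width=1024, guidance_scale=1.0
)
assert no_cfg.unet.batch_size == no_cfg.image_encoder.batch_size
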
optimum/rbln/diffusers/modeling_diffusers.py CHANGED
@@ -136,7 +136,7 @@ class RBLNDiffusionMixin:
         *,
         export: bool = None,
         model_save_dir: Optional[PathLike] = None,
-        rbln_config: Dict[str, Any] = {},
+        rbln_config: Optional[Dict[str, Any]] = None,
         lora_ids: Optional[Union[str, List[str]]] = None,
         lora_weights_names: Optional[Union[str, List[str]]] = None,
         lora_scales: Optional[Union[float, List[float]]] = None,
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -22,9 +22,11 @@ _import_structure = {
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLCosmos",
         "RBLNVQModel",
+        "RBLNAutoencoderKLTemporalDecoder",
     ],
     "unets": [
         "RBLNUNet2DConditionModel",
+        "RBLNUNetSpatioTemporalConditionModel",
     ],
     "controlnet": ["RBLNControlNetModel"],
     "transformers": [
@@ -35,10 +37,22 @@ _import_structure = {
 }
 
 if TYPE_CHECKING:
-    from .autoencoders import RBLNAutoencoderKL, RBLNAutoencoderKLCosmos, RBLNVQModel
+    from .autoencoders import (
+        RBLNAutoencoderKL,
+        RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLTemporalDecoder,
+        RBLNVQModel,
+    )
     from .controlnet import RBLNControlNetModel
-    from .transformers import RBLNCosmosTransformer3DModel, RBLNPriorTransformer, RBLNSD3Transformer2DModel
-    from .unets import RBLNUNet2DConditionModel
+    from .transformers import (
+        RBLNCosmosTransformer3DModel,
+        RBLNPriorTransformer,
+        RBLNSD3Transformer2DModel,
+    )
+    from .unets import (
+        RBLNUNet2DConditionModel,
+        RBLNUNetSpatioTemporalConditionModel,
+    )
 else:
     import sys
 
optimum/rbln/diffusers/models/autoencoders/__init__.py CHANGED
@@ -14,4 +14,5 @@
 
 from .autoencoder_kl import RBLNAutoencoderKL
 from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+from .autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoder
 from .vq_model import RBLNVQModel
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py CHANGED
@@ -68,7 +68,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
         self.image_size = self.rbln_config.image_size
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
     ) -> torch.nn.Module:
         decoder_model = _VAECosmosDecoder(model)
@@ -98,7 +98,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
 
         compiled_models = {}
         if rbln_config.uses_encoder:
-            encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
             enc_compiled_model = cls.compile(
                 encoder_model,
                 rbln_compile_config=rbln_config.compile_cfgs[0],
@@ -107,7 +107,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
             )
             compiled_models["encoder"] = enc_compiled_model
         else:
-            decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            decoder_model = cls._wrap_model_if_needed(model, rbln_config)
             dec_compiled_model = cls.compile(
                 decoder_model,
                 rbln_compile_config=rbln_config.compile_cfgs[-1],