optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ b/optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py
@@ -0,0 +1,114 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNCLIPVisionModelWithProjectionConfig
+from ..models import RBLNAutoencoderKLTemporalDecoderConfig, RBLNUNetSpatioTemporalConditionModelConfig
+
+
+class RBLNStableVideoDiffusionPipelineConfig(RBLNModelConfig):
+    submodules = ["image_encoder", "unet", "vae"]
+    _vae_uses_encoder = True
+
+    def __init__(
+        self,
+        image_encoder: Optional[RBLNCLIPVisionModelWithProjectionConfig] = None,
+        unet: Optional[RBLNUNetSpatioTemporalConditionModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLTemporalDecoderConfig] = None,
+        *,
+        batch_size: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        decode_chunk_size: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            image_encoder (Optional[RBLNCLIPVisionModelWithProjectionConfig]): Configuration for the image encoder component.
+                Initialized as RBLNCLIPVisionModelWithProjectionConfig if not provided.
+            unet (Optional[RBLNUNetSpatioTemporalConditionModelConfig]): Configuration for the UNet model component.
+                Initialized as RBLNUNetSpatioTemporalConditionModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLTemporalDecoderConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLTemporalDecoderConfig if not provided.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            height (Optional[int]): Height of the generated images.
+            width (Optional[int]): Width of the generated images.
+            num_frames (Optional[int]): The number of frames in the generated video.
+            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
+                Useful for managing memory usage during video generation.
+            guidance_scale (Optional[float]): Scale for classifier-free guidance.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If both image_size and height/width are provided.
+
+        Note:
+            When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+            accommodate classifier-free guidance.
+        """
+        super().__init__(**kwargs)
+        if height is not None and width is not None:
+            image_size = (height, width)
+        else:
+            # Get default image size from original class to set UNet, VAE image size
+            height = self.get_default_values_for_original_cls("__call__", ["height"])["height"]
+            width = self.get_default_values_for_original_cls("__call__", ["width"])["width"]
+            image_size = (height, width)
+
+        self.image_encoder = self.initialize_submodule_config(
+            image_encoder, cls_name="RBLNCLIPVisionModelWithProjectionConfig", batch_size=batch_size
+        )
+        self.unet = self.initialize_submodule_config(
+            unet,
+            cls_name="RBLNUNetSpatioTemporalConditionModelConfig",
+            num_frames=num_frames,
+        )
+        self.vae = self.initialize_submodule_config(
+            vae,
+            cls_name="RBLNAutoencoderKLTemporalDecoderConfig",
+            batch_size=batch_size,
+            num_frames=num_frames,
+            decode_chunk_size=decode_chunk_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            sample_size=image_size,  # image size is equal to sample size in vae
+        )
+
+        # Get default guidance scale from original class to set UNet batch size
+        if guidance_scale is None:
+            guidance_scale = self.get_default_values_for_original_cls("__call__", ["max_guidance_scale"])[
+                "max_guidance_scale"
+            ]
+
+        if not self.unet.batch_size_is_specified:
+            do_classifier_free_guidance = guidance_scale > 1.0
+            if do_classifier_free_guidance:
+                self.unet.batch_size = self.image_encoder.batch_size * 2
+            else:
+                self.unet.batch_size = self.image_encoder.batch_size
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def sample_size(self):
+        return self.unet.sample_size
+
+    @property
+    def image_size(self):
+        return self.vae.sample_size
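For orientation, a minimal usage sketch of the new pipeline config. The import path is an assumption (mirroring how other RBLN pipeline configs are re-exported from optimum.rbln), and the values are the usual Stable Video Diffusion defaults:

# Sketch only: assumes the config class is exported from optimum.rbln.
from optimum.rbln import RBLNStableVideoDiffusionPipelineConfig

config = RBLNStableVideoDiffusionPipelineConfig(
    batch_size=1,
    height=576,
    width=1024,
    num_frames=25,
    decode_chunk_size=5,
    guidance_scale=3.0,  # > 1.0, so classifier-free guidance is assumed
)

# Per the __init__ above: with guidance_scale > 1.0 and no explicit UNet batch
# size, unet.batch_size is doubled relative to the image encoder's batch size.
assert config.unet.batch_size == config.image_encoder.batch_size * 2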
--- a/optimum/rbln/diffusers/modeling_diffusers.py
+++ b/optimum/rbln/diffusers/modeling_diffusers.py
@@ -136,7 +136,7 @@ class RBLNDiffusionMixin:
         *,
         export: bool = None,
         model_save_dir: Optional[PathLike] = None,
-        rbln_config: Dict[str, Any] = {},
+        rbln_config: Optional[Dict[str, Any]] = None,
         lora_ids: Optional[Union[str, List[str]]] = None,
         lora_weights_names: Optional[Union[str, List[str]]] = None,
         lora_scales: Optional[Union[float, List[float]]] = None,
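The rbln_config change above swaps a mutable default argument for None. For context, the pitfall it avoids, in a minimal standalone form (illustrative only, not part of the package):

from typing import Optional


def bad(config: dict = {}):
    # The same dict object is reused for every call made without an argument,
    # so mutations leak across calls.
    config["touched"] = True
    return config


def good(config: Optional[dict] = None):
    config = {} if config is None else config
    config["touched"] = True
    return config


assert bad() is bad()        # shared object across calls
assert good() is not good()  # fresh dict per call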
--- a/optimum/rbln/diffusers/models/__init__.py
+++ b/optimum/rbln/diffusers/models/__init__.py
@@ -22,9 +22,11 @@ _import_structure = {
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLCosmos",
         "RBLNVQModel",
+        "RBLNAutoencoderKLTemporalDecoder",
     ],
     "unets": [
         "RBLNUNet2DConditionModel",
+        "RBLNUNetSpatioTemporalConditionModel",
     ],
     "controlnet": ["RBLNControlNetModel"],
     "transformers": [
@@ -35,10 +37,22 @@ _import_structure = {
 }
 
 if TYPE_CHECKING:
-    from .autoencoders import RBLNAutoencoderKL, RBLNAutoencoderKLCosmos, RBLNVQModel
+    from .autoencoders import (
+        RBLNAutoencoderKL,
+        RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLTemporalDecoder,
+        RBLNVQModel,
+    )
     from .controlnet import RBLNControlNetModel
-    from .transformers import RBLNCosmosTransformer3DModel, RBLNPriorTransformer, RBLNSD3Transformer2DModel
-    from .unets import RBLNUNet2DConditionModel
+    from .transformers import (
+        RBLNCosmosTransformer3DModel,
+        RBLNPriorTransformer,
+        RBLNSD3Transformer2DModel,
+    )
+    from .unets import (
+        RBLNUNet2DConditionModel,
+        RBLNUNetSpatioTemporalConditionModel,
+    )
 else:
     import sys
 
--- a/optimum/rbln/diffusers/models/autoencoders/__init__.py
+++ b/optimum/rbln/diffusers/models/autoencoders/__init__.py
@@ -14,4 +14,5 @@
 
 from .autoencoder_kl import RBLNAutoencoderKL
 from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+from .autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoder
 from .vq_model import RBLNVQModel
--- a/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
+++ b/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -68,7 +68,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
         self.image_size = self.rbln_config.image_size
 
     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
     ) -> torch.nn.Module:
         decoder_model = _VAECosmosDecoder(model)
@@ -98,7 +98,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
 
         compiled_models = {}
         if rbln_config.uses_encoder:
-            encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
             enc_compiled_model = cls.compile(
                 encoder_model,
                 rbln_compile_config=rbln_config.compile_cfgs[0],
@@ -107,7 +107,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
             )
             compiled_models["encoder"] = enc_compiled_model
         else:
-            decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            decoder_model = cls._wrap_model_if_needed(model, rbln_config)
         dec_compiled_model = cls.compile(
             decoder_model,
             rbln_compile_config=rbln_config.compile_cfgs[-1],
--- /dev/null
+++ b/optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -0,0 +1,275 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+import rebel
+import torch  # noqa: I001
+from diffusers import AutoencoderKLTemporalDecoder
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from transformers import PretrainedConfig
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
+from .vae import (
+    DiagonalGaussianDistribution,
+    RBLNRuntimeVAEDecoder,
+    RBLNRuntimeVAEEncoder,
+    _VAEEncoder,
+    _VAETemporalDecoder,
+)
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+logger = get_logger(__name__)
+
+
+class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
+    auto_model_class = AutoencoderKLTemporalDecoder
+    hf_library_name = "diffusers"
+    _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+
+        if self.rbln_config.uses_encoder:
+            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
+        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+        self.image_size = self.rbln_config.image_size
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+    ) -> torch.nn.Module:
+        decoder_model = _VAETemporalDecoder(model)
+        decoder_model.num_frames = rbln_config.decode_chunk_size
+        decoder_model.eval()
+
+        if rbln_config.uses_encoder:
+            encoder_model = _VAEEncoder(model)
+            encoder_model.eval()
+            return encoder_model, decoder_model
+        else:
+            return decoder_model
+
+    @classmethod
+    def get_compiled_model(
+        cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+    ) -> Dict[str, rebel.RBLNCompiledModel]:
+        compiled_models = {}
+        if rbln_config.uses_encoder:
+            encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+            enc_compiled_model = cls.compile(
+                encoder_model,
+                rbln_compile_config=rbln_config.compile_cfgs[0],
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device_map["encoder"],
+            )
+            compiled_models["encoder"] = enc_compiled_model
+        else:
+            decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+        dec_compiled_model = cls.compile(
+            decoder_model,
+            rbln_compile_config=rbln_config.compile_cfgs[-1],
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device_map["decoder"],
+        )
+        compiled_models["decoder"] = dec_compiled_model
+
+        return compiled_models
+
+    @classmethod
+    def get_vae_sample_size(
+        cls,
+        pipe: "RBLNDiffusionMixin",
+        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+        return_vae_scale_factor: bool = False,
+    ) -> Tuple[int, int]:
+        sample_size = rbln_config.sample_size
+        if hasattr(pipe, "vae_scale_factor"):
+            vae_scale_factor = pipe.vae_scale_factor
+        else:
+            if hasattr(pipe.vae.config, "block_out_channels"):
+                vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+            else:
+                vae_scale_factor = 8  # vae image processor default value 8 (int)
+
+        if sample_size is None:
+            sample_size = pipe.unet.config.sample_size
+            if isinstance(sample_size, int):
+                sample_size = (sample_size, sample_size)
+            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
+
+        if return_vae_scale_factor:
+            return sample_size, vae_scale_factor
+        else:
+            return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+            pipe, rbln_config.vae, return_vae_scale_factor=True
+        )
+
+        if rbln_config.vae.num_frames is None:
+            if hasattr(pipe.unet.config, "num_frames"):
+                rbln_config.vae.num_frames = pipe.unet.config.num_frames
+            else:
+                raise ValueError("num_frames should be specified in unet config.json")
+
+        if rbln_config.vae.decode_chunk_size is None:
+            rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames
+
+        def chunk_frame(num_frames, decode_chunk_size):
+            # get closest divisor to num_frames
+            divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
+            closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
+            if decode_chunk_size != closest:
+                logger.warning(
+                    f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
+                )
+            return closest
+
+        decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
+        rbln_config.vae.decode_chunk_size = decode_chunk_size
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+    ) -> RBLNAutoencoderKLTemporalDecoderConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
+
+        if rbln_config.vae_scale_factor is None:
+            if hasattr(model_config, "block_out_channels"):
+                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+            else:
+                # vae image processor default value 8 (int)
+                rbln_config.vae_scale_factor = 8
+
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
+            vae_enc_input_info = [
+                (
+                    "x",
+                    [
+                        rbln_config.batch_size,
+                        model_config.in_channels,
+                        rbln_config.sample_size[0],
+                        rbln_config.sample_size[1],
+                    ],
+                    "float32",
+                )
+            ]
+            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+        decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
+        vae_dec_input_info = [
+            (
+                "z",
+                [
+                    decode_batch_size,
+                    model_config.latent_channels,
+                    rbln_config.latent_sample_size[0],
+                    rbln_config.latent_sample_size[1],
+                ],
+                "float32",
+            )
+        ]
+        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+        rbln_config.set_compile_cfgs(compile_cfgs)
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+    ) -> List[rebel.Runtime]:
+        if len(compiled_models) == 1:
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            expected_models = ["encoder", "decoder"]
+
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
+
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+        return [
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
+            )
+            for compiled_model, device_val in zip(compiled_models, device_vals)
+        ]
+
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode an input image into a latent representation.
+
+        Args:
+            x: The input image to encode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+
+        Returns:
+            The latent representation or AutoencoderKLOutput if return_dict=True
+        """
+        posterior = self.encoder.encode(x)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+        """
+        Decode a latent representation into a video.
+
+        Args:
+            z: The latent representation to decode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+
+        Returns:
+            The decoded video or DecoderOutput if return_dict=True
+        """
+        decoded = self.decoder.decode(z)
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
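Two details above are worth calling out: update_rbln_config_using_pipe snaps decode_chunk_size to the nearest proper divisor of num_frames so the decoder compiles with a fixed batch, and _update_rbln_config then sizes the decoder input as batch_size * decode_chunk_size. A standalone copy of the divisor logic, for illustration:

def chunk_frame(num_frames: int, decode_chunk_size: int) -> int:
    # Proper divisors only: num_frames itself is excluded, as in the source above.
    divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
    # Pick the divisor closest to the requested chunk size.
    return min(divisors, key=lambda x: abs(x - decode_chunk_size))


# num_frames=14: the divisors are 1, 2, 7, so a requested chunk of 8 becomes 7.
assert chunk_frame(14, 8) == 7
# num_frames=25: the divisors are 1 and 5, so a requested chunk of 8 becomes 5.
assert chunk_frame(25, 8) == 5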
--- a/optimum/rbln/diffusers/models/autoencoders/vae.py
+++ b/optimum/rbln/diffusers/models/autoencoders/vae.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union
 
 import torch
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
 if TYPE_CHECKING:
-    from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
+    from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
 
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
@@ -67,18 +67,37 @@ class _VAEDecoder(torch.nn.Module):
         return vae_out
 
 
+class _VAETemporalDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
+        super().__init__()
+        self.vae = vae
+        self.num_frames = None
+
+    def forward(self, z):
+        vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
+        return vae_out
+
+
 class _VAEEncoder(torch.nn.Module):
-    def __init__(self, vae: "AutoencoderKL"):
+    def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
         super().__init__()
         self.vae = vae
 
     def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-            return self.tiled_encode(x, return_dict=return_dict)
+        if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
+            if self.use_tiling and (
+                x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
+            ):
+                return self.tiled_encode(x, return_dict=return_dict)
+
+            if self.use_slicing and x.shape[0] > 1:
+                encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+                h = torch.cat(encoded_slices)
+            else:
+                h = self.encoder(x)
+            if self.quant_conv is not None:
+                h = self.quant_conv(h)
 
-        if self.use_slicing and x.shape[0] > 1:
-            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-            h = torch.cat(encoded_slices)
         else:
             h = self.encoder(x)
             if self.quant_conv is not None:
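A sketch of how the new _VAETemporalDecoder wrapper is meant to be used at compile time: num_frames is pinned to decode_chunk_size before tracing so the decode call sees a constant, mirroring _wrap_model_if_needed above. The wrapper import path and the tiny default VAE config are illustrative assumptions:

import torch
from diffusers import AutoencoderKLTemporalDecoder

from optimum.rbln.diffusers.models.autoencoders.vae import _VAETemporalDecoder  # assumed path

vae = AutoencoderKLTemporalDecoder()  # small default config, enough for a shape check
wrapper = _VAETemporalDecoder(vae)
wrapper.num_frames = 7  # pinned to decode_chunk_size before compilation
wrapper.eval()

# One batch of 7 latent frames: (batch_size * decode_chunk_size, C, H, W)
z = torch.randn(7, vae.config.latent_channels, 16, 16)
with torch.no_grad():
    (frames,) = wrapper(z)  # calls vae.decode(z, num_frames=7, return_dict=False)
print(frames.shape)  # expected (7, 3, 16, 16) with the default config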
--- a/optimum/rbln/diffusers/models/controlnet.py
+++ b/optimum/rbln/diffusers/models/controlnet.py
@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
         )
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -215,10 +215,25 @@ class RBLNControlNetModel(RBLNModel):
         encoder_hidden_states: torch.Tensor,
         controlnet_cond: torch.FloatTensor,
         conditioning_scale: torch.Tensor = 1.0,
-        added_cond_kwargs: Dict[str, torch.Tensor] = {},
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         return_dict: bool = True,
         **kwargs,
     ):
+        """
+        Forward pass for the RBLN-optimized ControlNetModel.
+
+        Args:
+            sample (torch.FloatTensor): The noisy input tensor.
+            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+            encoder_hidden_states (torch.Tensor): The encoder hidden states.
+            controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
+            conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
+            added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
+        """
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
--- a/optimum/rbln/diffusers/models/transformers/prior_transformer.py
+++ b/optimum/rbln/diffusers/models/transformers/prior_transformer.py
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
         self.clip_std = artifacts["clip_std"]
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return _PriorTransformer(model).eval()
 
     @classmethod
@@ -128,13 +128,27 @@ class RBLNPriorTransformer(RBLNModel):
 
     def forward(
         self,
-        hidden_states,
+        hidden_states: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         proj_embedding: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized PriorTransformer.
+
+        Args:
+            hidden_states (torch.Tensor): The currently predicted image embeddings.
+            timestep (Union[torch.Tensor, float, int]): Current denoising step.
+            proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
+            encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
+            attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
+        """
         # Convert timestep(long) and attention_mask(bool) to float
         return super().forward(
             hidden_states,
--- a/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
+++ b/optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
@@ -185,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
         )
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         num_latent_frames = rbln_config.num_latent_frames
         latent_height = rbln_config.latent_height
         latent_width = rbln_config.latent_width
@@ -303,6 +303,21 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
         padding_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
+        """
+        Forward pass for the RBLN-optimized CosmosTransformer3DModel.
+
+        Args:
+            hidden_states (torch.Tensor): The currently predicted image embeddings.
+            timestep (torch.Tensor): Current denoising step.
+            encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            fps (Optional[int]): Frames per second for the video being generated.
+            condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
+            padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+        """
         (
             hidden_states,
             temb,
--- a/optimum/rbln/diffusers/models/transformers/transformer_sd3.py
+++ b/optimum/rbln/diffusers/models/transformers/transformer_sd3.py
@@ -77,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
         super().__post_init__(**kwargs)
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return SD3Transformer2DModelWrapper(model).eval()
 
     @classmethod
@@ -161,6 +161,19 @@ class RBLNSD3Transformer2DModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ):
+        """
+        Forward pass for the RBLN-optimized SD3Transformer2DModel.
+
+        Args:
+            hidden_states (torch.FloatTensor): The currently predicted image embeddings.
+            encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
+            timestep (torch.LongTensor): Current denoising step.
+            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+        Returns:
+            (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+        """
         sample_batch_size = hidden_states.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
--- a/optimum/rbln/diffusers/models/unets/__init__.py
+++ b/optimum/rbln/diffusers/models/unets/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 
 from .unet_2d_condition import RBLNUNet2DConditionModel
+from .unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModel
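Taken together, the additions above (pipeline config, temporal-decoder VAE, spatio-temporal UNet, and the new stable_video_diffusion pipeline package) wire up Stable Video Diffusion support. An end-to-end sketch; the pipeline class name and its export from optimum.rbln are assumptions based on the files listed above, and the input URL is a placeholder:

from diffusers.utils import export_to_video, load_image

from optimum.rbln import RBLNStableVideoDiffusionPipeline  # assumed export

# Compile once with export=True; per-submodule settings go through rbln_config
# (the keys mirror RBLNStableVideoDiffusionPipelineConfig.__init__ above).
pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    export=True,
    rbln_config={
        "height": 576,
        "width": 1024,
        "num_frames": 25,
        "decode_chunk_size": 5,  # a divisor of num_frames, see chunk_frame above
    },
)

image = load_image("https://example.com/conditioning_frame.png")  # placeholder URL
frames = pipe(image, num_frames=25, decode_chunk_size=5).frames[0]
export_to_video(frames, "generated.mp4", fps=7)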