optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Optional, Tuple

  from ....configuration_utils import RBLNModelConfig
  from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNT5EncoderModelConfig
@@ -40,7 +40,7 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
  height: Optional[int] = None,
  width: Optional[int] = None,
  guidance_scale: Optional[float] = None,
- **kwargs: Dict[str, Any],
+ **kwargs: Any,
  ):
  """
  Args:
@@ -64,7 +64,7 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
  height (Optional[int]): Height of the generated images.
  width (Optional[int]): Width of the generated images.
  guidance_scale (Optional[float]): Scale for classifier-free guidance.
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.

  Raises:
  ValueError: If both image_size and img_height/img_width are provided.
@@ -100,27 +100,31 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):

  max_seq_len = max_seq_len or 256

- self.text_encoder = self.init_submodule_config(
- RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
+ self.text_encoder = self.initialize_submodule_config(
+ text_encoder,
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
+ batch_size=batch_size,
  )
- self.text_encoder_2 = self.init_submodule_config(
- RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+ self.text_encoder_2 = self.initialize_submodule_config(
+ text_encoder_2,
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
+ batch_size=batch_size,
  )
- self.text_encoder_3 = self.init_submodule_config(
- RBLNT5EncoderModelConfig,
+ self.text_encoder_3 = self.initialize_submodule_config(
  text_encoder_3,
+ cls_name="RBLNT5EncoderModelConfig",
  batch_size=batch_size,
  max_seq_len=max_seq_len,
  model_input_names=["input_ids"],
  )
- self.transformer = self.init_submodule_config(
- RBLNSD3Transformer2DModelConfig,
+ self.transformer = self.initialize_submodule_config(
  transformer,
+ cls_name="RBLNSD3Transformer2DModelConfig",
  sample_size=sample_size,
  )
- self.vae = self.init_submodule_config(
- RBLNAutoencoderKLConfig,
+ self.vae = self.initialize_submodule_config(
  vae,
+ cls_name="RBLNAutoencoderKLConfig",
  batch_size=batch_size,
  uses_encoder=self.__class__._vae_uses_encoder,
  sample_size=image_size,
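This release renames `init_submodule_config` to `initialize_submodule_config` across the pipeline configs: the submodule config class is now passed by name via `cls_name` rather than as an imported class, and the submodule value moves to the first positional argument. A minimal sketch of the new call shape in a custom pipeline config, assuming `RBLNModelConfig` resolves `cls_name` to the registered config class (the surrounding class here is illustrative, not part of the package):

```python
from typing import Any, Optional

from optimum.rbln.configuration_utils import RBLNModelConfig


class MyPipelineConfig(RBLNModelConfig):
    # Hypothetical pipeline config showing the 0.9.x submodule pattern.
    submodules = ["text_encoder", "vae"]

    def __init__(self, text_encoder=None, vae=None, batch_size: Optional[int] = None, **kwargs: Any):
        super().__init__(**kwargs)
        # Old: self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
        # New: the config class is named, not imported, and the value comes first.
        self.text_encoder = self.initialize_submodule_config(
            text_encoder,
            cls_name="RBLNCLIPTextModelConfig",
            batch_size=batch_size,
        )
        self.vae = self.initialize_submodule_config(
            vae,
            cls_name="RBLNAutoencoderKLConfig",
            batch_size=batch_size,
        )
```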
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Optional, Tuple

  from ....configuration_utils import RBLNModelConfig
  from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
@@ -38,7 +38,7 @@ class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
  sample_size: Optional[Tuple[int, int]] = None,
  image_size: Optional[Tuple[int, int]] = None,
  guidance_scale: Optional[float] = None,
- **kwargs: Dict[str, Any],
+ **kwargs: Any,
  ):
  """
  Args:
@@ -59,7 +59,7 @@ class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
  image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
  Cannot be used together with img_height/img_width.
  guidance_scale (Optional[float]): Scale for classifier-free guidance.
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.

  Raises:
  ValueError: If both image_size and img_height/img_width are provided.
@@ -93,18 +93,25 @@ class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
  raise ValueError("Both img_height and img_width must be provided together if used")

- self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
- self.text_encoder_2 = self.init_submodule_config(
- RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
+ self.text_encoder = self.initialize_submodule_config(
+ text_encoder,
+ cls_name="RBLNCLIPTextModelConfig",
+ batch_size=batch_size,
+ )
+ self.text_encoder_2 = self.initialize_submodule_config(
+ text_encoder_2,
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
+ batch_size=batch_size,
  )
- self.unet = self.init_submodule_config(
- RBLNUNet2DConditionModelConfig,
+
+ self.unet = self.initialize_submodule_config(
  unet,
+ cls_name="RBLNUNet2DConditionModelConfig",
  sample_size=sample_size,
  )
- self.vae = self.init_submodule_config(
- RBLNAutoencoderKLConfig,
+ self.vae = self.initialize_submodule_config(
  vae,
+ cls_name="RBLNAutoencoderKLConfig",
  batch_size=batch_size,
  uses_encoder=self.__class__._vae_uses_encoder,
  sample_size=image_size, # image size is equal to sample size in vae
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py (new file)

@@ -0,0 +1,114 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....transformers import RBLNCLIPVisionModelWithProjectionConfig
+ from ..models import RBLNAutoencoderKLTemporalDecoderConfig, RBLNUNetSpatioTemporalConditionModelConfig
+
+
+ class RBLNStableVideoDiffusionPipelineConfig(RBLNModelConfig):
+ submodules = ["image_encoder", "unet", "vae"]
+ _vae_uses_encoder = True
+
+ def __init__(
+ self,
+ image_encoder: Optional[RBLNCLIPVisionModelWithProjectionConfig] = None,
+ unet: Optional[RBLNUNetSpatioTemporalConditionModelConfig] = None,
+ vae: Optional[RBLNAutoencoderKLTemporalDecoderConfig] = None,
+ *,
+ batch_size: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ decode_chunk_size: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ **kwargs: Any,
+ ):
+ """
+ Args:
+ image_encoder (Optional[RBLNCLIPVisionModelWithProjectionConfig]): Configuration for the image encoder component.
+ Initialized as RBLNCLIPVisionModelWithProjectionConfig if not provided.
+ unet (Optional[RBLNUNetSpatioTemporalConditionModelConfig]): Configuration for the UNet model component.
+ Initialized as RBLNUNetSpatioTemporalConditionModelConfig if not provided.
+ vae (Optional[RBLNAutoencoderKLTemporalDecoderConfig]): Configuration for the VAE model component.
+ Initialized as RBLNAutoencoderKLTemporalDecoderConfig if not provided.
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+ height (Optional[int]): Height of the generated images.
+ width (Optional[int]): Width of the generated images.
+ num_frames (Optional[int]): The number of frames in the generated video.
+ decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
+ Useful for managing memory usage during video generation.
+ guidance_scale (Optional[float]): Scale for classifier-free guidance.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+ Raises:
+ ValueError: If both image_size and height/width are provided.
+
+ Note:
+ When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+ accommodate classifier-free guidance.
+ """
+ super().__init__(**kwargs)
+ if height is not None and width is not None:
+ image_size = (height, width)
+ else:
+ # Get default image size from original class to set UNet, VAE image size
+ height = self.get_default_values_for_original_cls("__call__", ["height"])["height"]
+ width = self.get_default_values_for_original_cls("__call__", ["width"])["width"]
+ image_size = (height, width)
+
+ self.image_encoder = self.initialize_submodule_config(
+ image_encoder, cls_name="RBLNCLIPVisionModelWithProjectionConfig", batch_size=batch_size
+ )
+ self.unet = self.initialize_submodule_config(
+ unet,
+ cls_name="RBLNUNetSpatioTemporalConditionModelConfig",
+ num_frames=num_frames,
+ )
+ self.vae = self.initialize_submodule_config(
+ vae,
+ cls_name="RBLNAutoencoderKLTemporalDecoderConfig",
+ batch_size=batch_size,
+ num_frames=num_frames,
+ decode_chunk_size=decode_chunk_size,
+ uses_encoder=self.__class__._vae_uses_encoder,
+ sample_size=image_size, # image size is equal to sample size in vae
+ )
+
+ # Get default guidance scale from original class to set UNet batch size
+ if guidance_scale is None:
+ guidance_scale = self.get_default_values_for_original_cls("__call__", ["max_guidance_scale"])[
+ "max_guidance_scale"
+ ]
+
+ if not self.unet.batch_size_is_specified:
+ do_classifier_free_guidance = guidance_scale > 1.0
+ if do_classifier_free_guidance:
+ self.unet.batch_size = self.image_encoder.batch_size * 2
+ else:
+ self.unet.batch_size = self.image_encoder.batch_size
+
+ @property
+ def batch_size(self):
+ return self.vae.batch_size
+
+ @property
+ def sample_size(self):
+ return self.unet.sample_size
+
+ @property
+ def image_size(self):
+ return self.vae.sample_size
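The new Stable Video Diffusion config threads `num_frames` into both the UNet and the temporal-decoder VAE, falls back to the original pipeline's `__call__` defaults for `height`, `width`, and `max_guidance_scale`, and doubles the UNet batch size when classifier-free guidance is active. A hedged usage sketch: the import path follows the package layout above, and the concrete values (576x1024, 25 frames) are the usual SVD conventions rather than anything this diff pins down:

```python
from optimum.rbln.diffusers.configurations.pipelines.configuration_stable_video_diffusion import (
    RBLNStableVideoDiffusionPipelineConfig,
)

config = RBLNStableVideoDiffusionPipelineConfig(
    batch_size=1,
    height=576,
    width=1024,
    num_frames=25,
    decode_chunk_size=8,
    guidance_scale=3.0,  # > 1.0 enables classifier-free guidance
)

# With guidance active and no explicit UNet batch size, the UNet runs
# at twice the image encoder's batch size (assuming default resolution
# of the unspecified submodule configs).
assert config.unet.batch_size == 2 * config.image_encoder.batch_size
assert config.image_size == (576, 1024)  # vae.sample_size, per the properties above
```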
optimum/rbln/diffusers/modeling_diffusers.py

@@ -33,6 +33,10 @@ if TYPE_CHECKING:


  class RBLNDiffusionMixinConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLN diffusion pipelines.
+ """
+
  pass


@@ -54,8 +58,8 @@ class RBLNDiffusionMixin:
  ```

  Class Variables:
- _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
- _optional_submodules: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])
+ - `_submodules`: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
+ - `_optional_submodules`: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])

  Methods:
  from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
@@ -130,20 +134,20 @@
  cls,
  model_id: str,
  *,
- export: bool = False,
+ export: bool = None,
  model_save_dir: Optional[PathLike] = None,
  rbln_config: Dict[str, Any] = {},
  lora_ids: Optional[Union[str, List[str]]] = None,
  lora_weights_names: Optional[Union[str, List[str]]] = None,
  lora_scales: Optional[Union[float, List[float]]] = None,
- **kwargs: Dict[str, Any],
+ **kwargs: Any,
  ) -> "RBLNDiffusionMixin":
  """
  Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.

  This method has two distinct operating modes:
- - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
- - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+ - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+ - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation

  It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
@@ -170,7 +174,7 @@
  Names of specific LoRA weight files to load, corresponding to lora_ids. Only used when `export=True`.
  lora_scales:
  Scaling factor(s) to apply to the LoRA adapter(s). Only used when `export=True`.
- **kwargs:
+ kwargs:
  Additional arguments to pass to the underlying diffusion pipeline constructor or the
  RBLN compilation process. These may include parameters specific to individual submodules
  or the particular diffusion pipeline being used.
@@ -181,6 +185,20 @@
  """
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)

+ if export is None:
+ export = any(
+ not RBLNModel._is_compiled(
+ model_id,
+ token=kwargs.get("token"),
+ revision=kwargs.get("revision"),
+ force_download=kwargs.get("force_download", False),
+ cache_dir=kwargs.get("cache_dir"),
+ subfolder=submodule_name,
+ local_files_only=kwargs.get("local_files_only", False),
+ )
+ for submodule_name in cls._submodules
+ )
+
  if export:
  # keep submodules if user passed any of them.
  passed_submodules = {
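With `export` now defaulting to `None`, the hunk above makes `from_pretrained` probe each required submodule of `model_id` via `RBLNModel._is_compiled(...)` and recompile only when at least one submodule lacks compiled artifacts. Callers can therefore drop the explicit flag; a sketch of both paths (the pipeline class is real, the model IDs are placeholders):

```python
from optimum.rbln import RBLNStableDiffusionPipeline

# A plain diffusers checkpoint has no compiled submodules, so export
# is inferred as True and the pipeline is compiled for the RBLN NPU.
pipe = RBLNStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# A directory already holding compiled artifacts for every submodule
# loads directly, as if export=False had been passed.
pipe = RBLNStableDiffusionPipeline.from_pretrained("./compiled-sd15")
```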
optimum/rbln/diffusers/modeling_diffusers.py (continued)

@@ -226,7 +244,6 @@
  device=rbln_config.device,
  device_map=rbln_config.device_map,
  create_runtimes=rbln_config.create_runtimes,
- optimize_host_mem=rbln_config.optimize_host_memory,
  activate_profiler=rbln_config.activate_profiler,
  timeout=rbln_config.timeout,
  ):
@@ -394,12 +411,11 @@
  # overwrite to replace incorrect config
  model.save_config(model_save_dir)

- if rbln_config.optimize_host_memory is False:
- # Keep compiled_model objs to further analysis. -> TODO: remove soon...
- model.compiled_models = []
- for name in cls._submodules:
- submodule = getattr(model, name)
- model.compiled_models.extend(submodule.compiled_models)
+ # Keep compiled_model objs to further analysis. -> TODO: remove soon...
+ model.compiled_models = []
+ for name in cls._submodules:
+ submodule = getattr(model, name)
+ model.compiled_models.extend(submodule.compiled_models)

  return model

optimum/rbln/diffusers/models/__init__.py

@@ -22,9 +22,11 @@ _import_structure = {
  "RBLNAutoencoderKL",
  "RBLNAutoencoderKLCosmos",
  "RBLNVQModel",
+ "RBLNAutoencoderKLTemporalDecoder",
  ],
  "unets": [
  "RBLNUNet2DConditionModel",
+ "RBLNUNetSpatioTemporalConditionModel",
  ],
  "controlnet": ["RBLNControlNetModel"],
  "transformers": [
@@ -38,6 +40,7 @@ if TYPE_CHECKING:
  from .autoencoders import (
  RBLNAutoencoderKL,
  RBLNAutoencoderKLCosmos,
+ RBLNAutoencoderKLTemporalDecoder,
  RBLNVQModel,
  )
  from .controlnet import RBLNControlNetModel
@@ -48,6+51,7 @@ if TYPE_CHECKING:
  )
  from .unets import (
  RBLNUNet2DConditionModel,
+ RBLNUNetSpatioTemporalConditionModel,
  )
  else:
  import sys
optimum/rbln/diffusers/models/autoencoders/__init__.py

@@ -14,4 +14,5 @@

  from .autoencoder_kl import RBLNAutoencoderKL
  from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+ from .autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoder
  from .vq_model import RBLNVQModel
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

  import rebel
  import torch
@@ -214,13 +214,41 @@ class RBLNAutoencoderKL(RBLNModel):
  for compiled_model, device_val in zip(compiled_models, device_vals)
  ]

- def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
+ ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
+ """
+ Encode an input image into a latent representation.
+
+ Args:
+ x: The input image to encode.
+ return_dict:
+ Whether to return output as a dictionary. Defaults to True.
+ kwargs: Additional arguments to pass to the encoder.
+
+ Returns:
+ The latent representation or AutoencoderKLOutput if return_dict=True
+ """
  posterior = self.encoder.encode(x)
  if not return_dict:
  return (posterior,)
  return AutoencoderKLOutput(latent_dist=posterior)

- def decode(self, z: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
+ ) -> Union[torch.FloatTensor, DecoderOutput]:
+ """
+ Decode a latent representation into an image.
+
+ Args:
+ z: The latent representation to decode.
+ return_dict:
+ Whether to return output as a dictionary. Defaults to True.
+ kwargs: Additional arguments to pass to the decoder.
+
+ Returns:
+ The decoded image or DecoderOutput if return_dict=True
+ """
  dec = self.decoder.decode(z)
  if not return_dict:
  return (dec,)
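The newly documented `encode`/`decode` signatures mirror the diffusers `AutoencoderKL` API: `encode` returns an `AutoencoderKLOutput` whose `latent_dist` can be sampled, and `decode` returns a `DecoderOutput` (or a one-element tuple with `return_dict=False`). A minimal round-trip sketch; the model path and tensor shape are placeholders:

```python
import torch
from optimum.rbln import RBLNAutoencoderKL

# Load a VAE that was previously compiled and saved to disk.
vae = RBLNAutoencoderKL.from_pretrained("./compiled-vae", export=False)

image = torch.randn(1, 3, 512, 512)  # placeholder input batch
latents = vae.encode(image).latent_dist.sample()

decoded = vae.decode(latents)                        # DecoderOutput
tensor = vae.decode(latents, return_dict=False)[0]   # plain tensor
```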
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Dict, List, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Union

  import rebel
  import torch
@@ -68,7 +68,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  self.image_size = self.rbln_config.image_size

  @classmethod
- def wrap_model_if_needed(
+ def _wrap_model_if_needed(
  cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
  ) -> torch.nn.Module:
  decoder_model = _VAECosmosDecoder(model)
@@ -98,7 +98,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):

  compiled_models = {}
  if rbln_config.uses_encoder:
- encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+ encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
  enc_compiled_model = cls.compile(
  encoder_model,
  rbln_compile_config=rbln_config.compile_cfgs[0],
@@ -107,7 +107,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  )
  compiled_models["encoder"] = enc_compiled_model
  else:
- decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+ decoder_model = cls._wrap_model_if_needed(model, rbln_config)
  dec_compiled_model = cls.compile(
  decoder_model,
  rbln_compile_config=rbln_config.compile_cfgs[-1],
@@ -205,13 +205,38 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  for compiled_model, device_val in zip(compiled_models, device_vals)
  ]

- def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
+ ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
+ """
+ Encode an input video into a latent representation.
+
+ Args:
+ x: The input video to encode.
+ return_dict:
+ Whether to return output as a dictionary. Defaults to True.
+ kwargs: Additional arguments to pass to the encoder.
+
+ Returns:
+ The latent representation or AutoencoderKLOutput if return_dict=True
+ """
  posterior = self.encoder.encode(x)
  if not return_dict:
  return (posterior,)
  return AutoencoderKLOutput(latent_dist=posterior)

- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[torch.FloatTensor, DecoderOutput]:
+ """
+ Decode a latent representation into a video.
+
+ Args:
+ z: The latent representation to decode.
+ return_dict:
+ Whether to return output as a dictionary. Defaults to True.
+
+ Returns:
+ The decoded video or DecoderOutput if return_dict=True
+ """
  decoded = self.decoder.decode(z)

  if not return_dict: