optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

@@ -0,0 +1,275 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+ import rebel
+ import torch # noqa: I001
+ from diffusers import AutoencoderKLTemporalDecoder
+ from diffusers.models.autoencoders.vae import DecoderOutput
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
+ from transformers import PretrainedConfig
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from .vae import (
+     DiagonalGaussianDistribution,
+     RBLNRuntimeVAEDecoder,
+     RBLNRuntimeVAEEncoder,
+     _VAEEncoder,
+     _VAETemporalDecoder,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+     from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+ logger = get_logger(__name__)
+
+
+ class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
+     auto_model_class = AutoencoderKLTemporalDecoder
+     hf_library_name = "diffusers"
+     _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+
+         if self.rbln_config.uses_encoder:
+             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
+         self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+         self.image_size = self.rbln_config.image_size
+
+     @classmethod
+     def _wrap_model_if_needed(
+         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+     ) -> torch.nn.Module:
+         decoder_model = _VAETemporalDecoder(model)
+         decoder_model.num_frames = rbln_config.decode_chunk_size
+         decoder_model.eval()
+
+         if rbln_config.uses_encoder:
+             encoder_model = _VAEEncoder(model)
+             encoder_model.eval()
+             return encoder_model, decoder_model
+         else:
+             return decoder_model
+
+     @classmethod
+     def get_compiled_model(
+         cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+     ) -> Dict[str, rebel.RBLNCompiledModel]:
+         compiled_models = {}
+         if rbln_config.uses_encoder:
+             encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+             enc_compiled_model = cls.compile(
+                 encoder_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[0],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map["encoder"],
+             )
+             compiled_models["encoder"] = enc_compiled_model
+         else:
+             decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+         dec_compiled_model = cls.compile(
+             decoder_model,
+             rbln_compile_config=rbln_config.compile_cfgs[-1],
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device_map["decoder"],
+         )
+         compiled_models["decoder"] = dec_compiled_model
+
+         return compiled_models
+
+     @classmethod
+     def get_vae_sample_size(
+         cls,
+         pipe: "RBLNDiffusionMixin",
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+         return_vae_scale_factor: bool = False,
+     ) -> Tuple[int, int]:
+         sample_size = rbln_config.sample_size
+         if hasattr(pipe, "vae_scale_factor"):
+             vae_scale_factor = pipe.vae_scale_factor
+         else:
+             if hasattr(pipe.vae.config, "block_out_channels"):
+                 vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+             else:
+                 vae_scale_factor = 8 # vae image processor default value 8 (int)
+
+         if sample_size is None:
+             sample_size = pipe.unet.config.sample_size
+             if isinstance(sample_size, int):
+                 sample_size = (sample_size, sample_size)
+             sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
+
+         if return_vae_scale_factor:
+             return sample_size, vae_scale_factor
+         else:
+             return sample_size
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+             pipe, rbln_config.vae, return_vae_scale_factor=True
+         )
+
+         if rbln_config.vae.num_frames is None:
+             if hasattr(pipe.unet.config, "num_frames"):
+                 rbln_config.vae.num_frames = pipe.unet.config.num_frames
+             else:
+                 raise ValueError("num_frames should be specified in unet config.json")
+
+         if rbln_config.vae.decode_chunk_size is None:
+             rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames
+
+         def chunk_frame(num_frames, decode_chunk_size):
+             # get closest divisor to num_frames
+             divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
+             closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
+             if decode_chunk_size != closest:
+                 logger.warning(
+                     f"To ensure successful model compilation and prevent device OOM, {decode_chunk_size} is set to {closest}."
+                 )
+             return closest
+
+         decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
+         rbln_config.vae.decode_chunk_size = decode_chunk_size
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+     ) -> RBLNAutoencoderKLTemporalDecoderConfig:
+         if rbln_config.sample_size is None:
+             rbln_config.sample_size = model_config.sample_size
+
+         if rbln_config.vae_scale_factor is None:
+             if hasattr(model_config, "block_out_channels"):
+                 rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+             else:
+                 # vae image processor default value 8 (int)
+                 rbln_config.vae_scale_factor = 8
+
+         compile_cfgs = []
+         if rbln_config.uses_encoder:
+             vae_enc_input_info = [
+                 (
+                     "x",
+                     [
+                         rbln_config.batch_size,
+                         model_config.in_channels,
+                         rbln_config.sample_size[0],
+                         rbln_config.sample_size[1],
+                     ],
+                     "float32",
+                 )
+             ]
+             compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+         decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
+         vae_dec_input_info = [
+             (
+                 "z",
+                 [
+                     decode_batch_size,
+                     model_config.latent_channels,
+                     rbln_config.latent_sample_size[0],
+                     rbln_config.latent_sample_size[1],
+                 ],
+                 "float32",
+             )
+         ]
+         compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+         rbln_config.set_compile_cfgs(compile_cfgs)
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+     ) -> List[rebel.Runtime]:
+         if len(compiled_models) == 1:
+             # decoder
+             expected_models = ["decoder"]
+         else:
+             expected_models = ["encoder", "decoder"]
+
+         if any(model_name not in rbln_config.device_map for model_name in expected_models):
+             cls._raise_missing_compiled_file_error(expected_models)
+
+         device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+         return [
+             rebel.Runtime(
+                 compiled_model,
+                 tensor_type="pt",
+                 device=device_val,
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             )
+             for compiled_model, device_val in zip(compiled_models, device_vals)
+         ]
+
+     def encode(
+         self, x: torch.FloatTensor, return_dict: bool = True
+     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+         """
+         Encode an input image into a latent representation.
+
+         Args:
+             x: The input image to encode.
+             return_dict:
+                 Whether to return output as a dictionary. Defaults to True.
+
+         Returns:
+             The latent representation or AutoencoderKLOutput if return_dict=True
+         """
+         posterior = self.encoder.encode(x)
+
+         if not return_dict:
+             return (posterior,)
+
+         return AutoencoderKLOutput(latent_dist=posterior)
+
+     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+         """
+         Decode a latent representation into a video.
+
+         Args:
+             z: The latent representation to decode.
+             return_dict:
+                 Whether to return output as a dictionary. Defaults to True.
+
+         Returns:
+             The decoded video or DecoderOutput if return_dict=True
+         """
+         decoded = self.decoder.decode(z)
+
+         if not return_dict:
+             return (decoded,)
+
+         return DecoderOutput(sample=decoded)
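
The new class above plugs into the standard RBLNModel export flow (compile once, then run on the NPU runtime). The snippet below is a minimal usage sketch rather than documented API: the top-level export of RBLNAutoencoderKLTemporalDecoder, the export flag, and the rbln_config keys (batch_size, uses_encoder, decode_chunk_size) are assumptions inferred from the configuration class referenced in this file.

# Hypothetical sketch: compiling and running the temporal-decoder VAE standalone.
import torch
from optimum.rbln import RBLNAutoencoderKLTemporalDecoder  # assumed top-level export

vae = RBLNAutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",  # SVD checkpoint; the VAE sits in the "vae" subfolder
    subfolder="vae",
    export=True,  # compile for the RBLN NPU instead of loading a precompiled artifact
    rbln_config={"batch_size": 1, "uses_encoder": True, "decode_chunk_size": 2},
)

# The compiled decoder expects batch_size * decode_chunk_size latent frames per call,
# mirroring the decoder input_info built in _update_rbln_config above.
latents = torch.randn(2, 4, 72, 128)  # 4 latent channels; 576x1024 frames at vae_scale_factor=8
video_frames = vae.decode(latents).sample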

optimum/rbln/diffusers/models/autoencoders/vae.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, List
+ from typing import TYPE_CHECKING, List, Union

  import torch
  from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime


  if TYPE_CHECKING:
-     from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
+     from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel


  class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
@@ -67,18 +67,37 @@ class _VAEDecoder(torch.nn.Module):
          return vae_out


+ class _VAETemporalDecoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
+         super().__init__()
+         self.vae = vae
+         self.num_frames = None
+
+     def forward(self, z):
+         vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
+         return vae_out
+
+
  class _VAEEncoder(torch.nn.Module):
-     def __init__(self, vae: "AutoencoderKL"):
+     def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
          super().__init__()
          self.vae = vae

      def encode(self, x: torch.FloatTensor, return_dict: bool = True):
-         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
-             return self.tiled_encode(x, return_dict=return_dict)
+         if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
+             if self.use_tiling and (
+                 x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
+             ):
+                 return self.tiled_encode(x, return_dict=return_dict)
+
+             if self.use_slicing and x.shape[0] > 1:
+                 encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+                 h = torch.cat(encoded_slices)
+             else:
+                 h = self.encoder(x)
+                 if self.quant_conv is not None:
+                     h = self.quant_conv(h)

-         if self.use_slicing and x.shape[0] > 1:
-             encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
-             h = torch.cat(encoded_slices)
          else:
              h = self.encoder(x)
              if self.quant_conv is not None:

optimum/rbln/diffusers/models/autoencoders/vq_model.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, List, Union
+ from typing import TYPE_CHECKING, Any, List, Union

  import rebel
  import torch
@@ -165,17 +165,46 @@ class RBLNVQModel(RBLNModel):
                  tensor_type="pt",
                  device=device_val,
                  activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
              )
              for compiled_model, device_val in zip(compiled_models, device_vals)
          ]

-     def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+     def encode(
+         self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+     ) -> Union[torch.FloatTensor, VQEncoderOutput]:
+         """
+         Encode an input image into a quantized latent representation.
+
+         Args:
+             x: The input image to encode.
+             return_dict:
+                 Whether to return output as a dictionary. Defaults to True.
+             kwargs: Additional arguments to pass to the encoder/quantizer.
+
+         Returns:
+             The quantized latent representation or a specific output object.
+         """
          posterior = self.encoder.encode(x)
          if not return_dict:
              return (posterior,)
          return VQEncoderOutput(latents=posterior)

-     def decode(self, h: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+     def decode(
+         self, h: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+     ) -> Union[torch.FloatTensor, DecoderOutput]:
+         """
+         Decode a quantized latent representation back into an image.
+
+         Args:
+             h: The quantized latent representation to decode.
+             return_dict:
+                 Whether to return output as a dictionary. Defaults to True.
+             kwargs: Additional arguments to pass to the decoder.
+
+         Returns:
+             The decoded image or a DecoderOutput object.
+         """
          dec, commit_loss = self.decoder.decode(h, **kwargs)
          if not return_dict:
              return (dec, commit_loss)

optimum/rbln/diffusers/models/controlnet.py

@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
          )

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          use_encoder_hidden_states = False
          for down_block in model.down_blocks:
              if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -219,6 +219,21 @@ class RBLNControlNetModel(RBLNModel):
          return_dict: bool = True,
          **kwargs,
      ):
+         """
+         Forward pass for the RBLN-optimized ControlNetModel.
+
+         Args:
+             sample (torch.FloatTensor): The noisy input tensor.
+             timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+             encoder_hidden_states (torch.Tensor): The encoder hidden states.
+             controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
+             conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
+             added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
+
+         Returns:
+             (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
+         """
          sample_batch_size = sample.size()[0]
          compiled_batch_size = self.compiled_batch_size
          if sample_batch_size != compiled_batch_size and (

optimum/rbln/diffusers/models/transformers/prior_transformer.py

@@ -59,7 +59,7 @@ class RBLNPriorTransformer(RBLNModel):
      """
      RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.

-     The Prior Transformer takes text and/or image embeddings from encoders (like CLIP) and
+     The PriorTransformer takes text and/or image embeddings from encoders (like CLIP) and
      maps them to a shared latent space that guides the diffusion process to generate the desired image.

      This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
          self.clip_std = artifacts["clip_std"]

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          return _PriorTransformer(model).eval()

      @classmethod
@@ -128,13 +128,27 @@

      def forward(
          self,
-         hidden_states,
+         hidden_states: torch.Tensor,
          timestep: Union[torch.Tensor, float, int],
          proj_embedding: torch.Tensor,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          return_dict: bool = True,
      ):
+         """
+         Forward pass for the RBLN-optimized PriorTransformer.
+
+         Args:
+             hidden_states (torch.Tensor): The currently predicted image embeddings.
+             timestep (Union[torch.Tensor, float, int]): Current denoising step.
+             proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
+             encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
+             attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
+         """
          # Convert timestep(long) and attention_mask(bool) to float
          return super().forward(
              hidden_states,

optimum/rbln/diffusers/models/transformers/transformer_cosmos.py

@@ -94,7 +94,15 @@ class CosmosTransformer3DModelWrapper(torch.nn.Module):


  class RBLNCosmosTransformer3DModel(RBLNModel):
-     """RBLN wrapper for the Cosmos Transformer model."""
+     """
+     RBLN implementation of CosmosTransformer3DModel for diffusion models like Cosmos.
+
+     The CosmosTransformer3DModel takes text and/or image embeddings from encoders (like CLIP) and
+     maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+     the library implements for all its models.
+     """

      hf_library_name = "diffusers"
      auto_model_class = CosmosTransformer3DModel
@@ -177,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
          )

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          num_latent_frames = rbln_config.num_latent_frames
          latent_height = rbln_config.latent_height
          latent_width = rbln_config.latent_width
@@ -279,7 +287,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
                  tensor_type="pt",
                  device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
                  activate_profiler=rbln_config.activate_profiler,
-                 timeout=120,
+                 timeout=rbln_config.timeout,
              )
              for compiled_model in compiled_models
          ]
@@ -295,6 +303,21 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
          padding_mask: Optional[torch.Tensor] = None,
          return_dict: bool = True,
      ):
+         """
+         Forward pass for the RBLN-optimized CosmosTransformer3DModel.
+
+         Args:
+             hidden_states (torch.Tensor): The currently predicted image embeddings.
+             timestep (torch.Tensor): Current denoising step.
+             encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             fps: (Optional[int]): Frames per second for the video being generated.
+             condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
+             padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+         """
          (
              hidden_states,
              temb,

optimum/rbln/diffusers/models/transformers/transformer_sd3.py

@@ -59,7 +59,15 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):


  class RBLNSD3Transformer2DModel(RBLNModel):
-     """RBLN wrapper for the Stable Diffusion 3 MMDiT Transformer model."""
+     """
+     RBLN implementation of SD3Transformer2DModel for diffusion models like Stable Diffusion 3.
+
+     The SD3Transformer2DModel takes text and/or image embeddings from encoders (like CLIP) and
+     maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+     the library implements for all its models.
+     """

      hf_library_name = "diffusers"
      auto_model_class = SD3Transformer2DModel
@@ -69,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
          super().__post_init__(**kwargs)

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          return SD3Transformer2DModelWrapper(model).eval()

      @classmethod
@@ -153,6 +161,19 @@ class RBLNSD3Transformer2DModel(RBLNModel):
          return_dict: bool = True,
          **kwargs,
      ):
+         """
+         Forward pass for the RBLN-optimized SD3Transformer2DModel.
+
+         Args:
+             hidden_states (torch.FloatTensor): The currently predicted image embeddings.
+             encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
+             timestep (torch.LongTensor): Current denoising step.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
+         """
          sample_batch_size = hidden_states.size()[0]
          compiled_batch_size = self.compiled_batch_size
          if sample_batch_size != compiled_batch_size and (

optimum/rbln/diffusers/models/unets/__init__.py

@@ -13,3 +13,4 @@
  # limitations under the License.

  from .unet_2d_condition import RBLNUNet2DConditionModel
+ from .unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModel
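
Together with the stable_video_diffusion pipeline files in the changed-file list at the top, the import registered above suggests end-to-end Stable Video Diffusion support on RBLN devices. A rough, hypothetical driver for that pipeline is sketched below; the class name RBLNStableVideoDiffusionPipeline and every keyword argument are inferred from the file list and the usual optimum-rbln conventions, not from documented API.

# Hypothetical sketch: image-to-video generation with the new SVD pipeline.
from diffusers.utils import load_image
from optimum.rbln import RBLNStableVideoDiffusionPipeline  # assumed class name

pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    export=True,  # compile the UNet, temporal VAE, and image encoder for the NPU
    rbln_config={"unet": {"batch_size": 2}, "vae": {"decode_chunk_size": 2}},
)

image = load_image("conditioning_frame.png")  # placeholder path to the conditioning image
frames = pipe(image, num_frames=14).frames[0]  # call shape follows the diffusers SVD pipeline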

optimum/rbln/diffusers/models/unets/unet_2d_condition.py

@@ -141,10 +141,13 @@ class _UNet_Kandinsky(torch.nn.Module):

  class RBLNUNet2DConditionModel(RBLNModel):
      """
-     Configuration class for RBLN UNet2DCondition models.
+     RBLN implementation of UNet2DConditionModel for diffusion models.

-     This class inherits from RBLNModelConfig and provides specific configuration options
-     for UNet2DCondition models used in diffusion-based image generation.
+     This model is used to accelerate UNet2DCondition models from diffusers library on RBLN NPUs.
+     It is a key component in diffusion-based image generation models like Stable Diffusion.
+
+     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+     the library implements for all its models.
      """

      hf_library_name = "diffusers"
@@ -168,7 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
          self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

      @classmethod
-     def wrap_model_if_needed(
+     def _wrap_model_if_needed(
          cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
      ) -> torch.nn.Module:
          if model.config.addition_embed_type == "text_time":
@@ -346,6 +349,22 @@ class RBLNUNet2DConditionModel(RBLNModel):
          return_dict: bool = True,
          **kwargs,
      ) -> Union[UNet2DConditionOutput, Tuple]:
+         """
+         Forward pass for the RBLN-optimized UNet2DConditionModel.
+
+         Args:
+             sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
+             timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
+             encoder_hidden_states (torch.Tensor): The encoder hidden states.
+             added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
+                 if specified are added to the embeddings that are passed along to the UNet blocks.
+             down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
+             mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
+             return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+         Returns:
+             (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
+         """
          sample_batch_size = sample.size()[0]
          compiled_batch_size = self.compiled_batch_size
          if sample_batch_size != compiled_batch_size and (