optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196) hide show
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,201 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from diffusers.models.unets.unet_spatio_temporal_condition import (
20
+ UNetSpatioTemporalConditionModel,
21
+ UNetSpatioTemporalConditionOutput,
22
+ )
23
+ from transformers import PretrainedConfig
24
+
25
+ from ....configuration_utils import RBLNCompileConfig
26
+ from ....modeling import RBLNModel
27
+ from ....utils.logging import get_logger
28
+ from ...configurations import RBLNUNetSpatioTemporalConditionModelConfig
29
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers import AutoFeatureExtractor, AutoProcessor, PreTrainedModel
34
+
35
+ logger = get_logger(__name__)
36
+
37
+
38
class _UNet_STCM(torch.nn.Module):
    """
    Thin `torch.nn.Module` adapter around a spatio-temporal conditional UNet.

    The adapter exposes a fixed, keyword-free `forward` signature and always calls the
    wrapped model with `return_dict=False`.
    """

    def __init__(self, unet: "UNetSpatioTemporalConditionModel"):
        super().__init__()
        # Keep the wrapped model as a registered submodule.
        self.unet = unet

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Forward all four inputs to the wrapped UNet by keyword and return its raw result."""
        # NOTE(review): with return_dict=False the wrapped UNet presumably returns a tuple,
        # not a bare tensor, despite the return annotation - confirm against diffusers.
        return self.unet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=added_time_ids,
            return_dict=False,
        )
58
+
59
+
60
class RBLNUNetSpatioTemporalConditionModel(RBLNModel):
    """
    RBLN wrapper for diffusers' `UNetSpatioTemporalConditionModel`.

    The class declares how the original UNet is wrapped for compilation, how the static
    input shapes are derived from the model and RBLN configs, and how inputs are checked
    and forwarded at inference time.
    """

    hf_library_name = "diffusers"
    auto_model_class = UNetSpatioTemporalConditionModel
    _rbln_config_class = RBLNUNetSpatioTemporalConditionModelConfig
    output_class = UNetSpatioTemporalConditionOutput
    output_key = "sample"

    def __post_init__(self, **kwargs):
        """Expose `in_features` and, when it is set, a minimal `add_embedding.linear_1` stand-in."""
        super().__post_init__(**kwargs)
        self.in_features = self.rbln_config.in_features
        if self.in_features is not None:
            # NOTE(review): presumably pipeline code reads
            # `unet.add_embedding.linear_1.in_features`; these two tiny dataclasses provide
            # that attribute path without the original layers - confirm against the pipeline.

            @dataclass
            class LINEAR1:
                in_features: int

            @dataclass
            class ADDEMBEDDING:
                linear_1: LINEAR1

            self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNUNetSpatioTemporalConditionModelConfig
    ) -> torch.nn.Module:
        """Wrap `model` in `_UNet_STCM` and switch it to eval mode."""
        return _UNet_STCM(model).eval()

    @classmethod
    def get_unet_sample_size(
        cls,
        pipe: RBLNDiffusionMixin,
        rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> Union[int, Tuple[int, int]]:
        """
        Compute the UNet's (height, width) sample size from pixel-space dimensions.

        Args:
            pipe: Pipeline providing `vae_scale_factor` and `vae.config.sample_size`.
            rbln_config: Accepted for interface compatibility; not read here.
            image_size: Optional (height, width). When omitted, the VAE's configured
                sample size is used instead.

        Returns:
            A 2-tuple with each dimension floor-divided by the VAE scale factor.
        """
        scale_factor = pipe.vae_scale_factor

        if image_size is None:
            vae_sample_size = pipe.vae.config.sample_size
            if isinstance(vae_sample_size, int):
                # A single int describes a square input.
                vae_sample_size = (vae_sample_size, vae_sample_size)

            sample_size = (
                vae_sample_size[0] // scale_factor,
                vae_sample_size[1] // scale_factor,
            )
        else:
            sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
        return sample_size

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Set `rbln_config.unet.sample_size` from the pipeline and return the same config object."""
        rbln_config.unet.sample_size = cls.get_unet_sample_size(
            pipe, rbln_config.unet, image_size=rbln_config.image_size
        )
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
    ) -> RBLNUNetSpatioTemporalConditionModelConfig:
        """
        Fill in missing RBLN config fields from `model_config` and register the compile config.

        `num_frames` and `sample_size` fall back to the model config when unset. An integer
        `sample_size` is expanded to a square (size, size) tuple before it is indexed.

        Returns:
            The same `rbln_config`, with a single compile config attached.
        """
        if rbln_config.num_frames is None:
            rbln_config.num_frames = model_config.num_frames

        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        # The fallback above may be a bare int; normalise it the same way
        # `get_unet_sample_size` does so the indexing below cannot fail on an int.
        sample_size = rbln_config.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
            rbln_config.sample_size = sample_size

        input_info = [
            (
                "sample",
                [
                    rbln_config.batch_size,
                    rbln_config.num_frames,
                    model_config.in_channels,
                    sample_size[0],
                    sample_size[1],
                ],
                "float32",
            ),
            ("timestep", [], "float32"),
            ("encoder_hidden_states", [rbln_config.batch_size, 1, model_config.cross_attention_dim], "float32"),
            ("added_time_ids", [rbln_config.batch_size, 3], "float32"),
        ]

        # Guard on the attribute that is actually read as well as the original marker.
        if hasattr(model_config, "addition_time_embed_dim") and hasattr(
            model_config, "projection_class_embeddings_input_dim"
        ):
            rbln_config.in_features = model_config.projection_class_embeddings_input_dim

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])

        return rbln_config

    @property
    def compiled_batch_size(self):
        """Batch dimension of the first compiled input (`sample`)."""
        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
        """
        Forward pass for the RBLN-optimized UNetSpatioTemporalConditionModel.

        Args:
            sample (torch.Tensor): The noisy input tensor, compiled with shape
                `(batch, num_frames, channel, height, width)`.
            timestep (Union[torch.Tensor, float, int]): The denoising timestep. A Python
                number is converted to a scalar tensor; the value is passed on as float32.
            encoder_hidden_states (torch.Tensor): The encoder hidden states, compiled with
                shape `(batch, 1, cross_attention_dim)`.
            added_time_ids (torch.Tensor): Additional time ids, compiled with shape `(batch, 3)`.
            return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] instead of a plain tuple.
            **kwargs: Accepted for call-site compatibility; ignored here.

        Returns:
            (Union[`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`], Tuple)

        Raises:
            ValueError: If the runtime batch size is exactly half or double the compiled
                batch size.
        """
        sample_batch_size = sample.size()[0]
        compiled_batch_size = self.compiled_batch_size
        # Only the half/double mismatch is rejected here; other mismatches are passed through.
        if sample_batch_size != compiled_batch_size and (
            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
        ):
            raise ValueError(
                f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
            )

        # The signature allows plain Python numbers, which have no `.float()` method.
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep)

        return super().forward(
            sample.contiguous(),
            timestep.float(),
            encoder_hidden_states,
            added_time_ids,
            return_dict=return_dict,
        )
@@ -18,6 +18,11 @@ from transformers.utils import _LazyModule
18
18
 
19
19
 
20
20
  _import_structure = {
21
+ "auto_pipeline": [
22
+ "RBLNAutoPipelineForImage2Image",
23
+ "RBLNAutoPipelineForInpainting",
24
+ "RBLNAutoPipelineForText2Image",
25
+ ],
21
26
  "controlnet": [
22
27
  "RBLNMultiControlNetModel",
23
28
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -54,8 +59,16 @@ _import_structure = {
54
59
  "RBLNStableDiffusion3Img2ImgPipeline",
55
60
  "RBLNStableDiffusion3InpaintPipeline",
56
61
  ],
62
+ "stable_video_diffusion": [
63
+ "RBLNStableVideoDiffusionPipeline",
64
+ ],
57
65
  }
58
66
  if TYPE_CHECKING:
67
+ from .auto_pipeline import (
68
+ RBLNAutoPipelineForImage2Image,
69
+ RBLNAutoPipelineForInpainting,
70
+ RBLNAutoPipelineForText2Image,
71
+ )
59
72
  from .controlnet import (
60
73
  RBLNMultiControlNetModel,
61
74
  RBLNStableDiffusionControlNetImg2ImgPipeline,
@@ -63,11 +76,7 @@ if TYPE_CHECKING:
63
76
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
64
77
  RBLNStableDiffusionXLControlNetPipeline,
65
78
  )
66
- from .cosmos import (
67
- RBLNCosmosSafetyChecker,
68
- RBLNCosmosTextToWorldPipeline,
69
- RBLNCosmosVideoToWorldPipeline,
70
- )
79
+ from .cosmos import RBLNCosmosSafetyChecker, RBLNCosmosTextToWorldPipeline, RBLNCosmosVideoToWorldPipeline
71
80
  from .kandinsky2_2 import (
72
81
  RBLNKandinskyV22CombinedPipeline,
73
82
  RBLNKandinskyV22Img2ImgCombinedPipeline,
@@ -92,6 +101,7 @@ if TYPE_CHECKING:
92
101
  RBLNStableDiffusionXLInpaintPipeline,
93
102
  RBLNStableDiffusionXLPipeline,
94
103
  )
104
+ from .stable_video_diffusion import RBLNStableVideoDiffusionPipeline
95
105
  else:
96
106
  import sys
97
107
 
@@ -0,0 +1,307 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import importlib
17
+ from pathlib import Path
18
+ from typing import Any, Dict, Type, Union
19
+
20
+ from diffusers.models.controlnets import ControlNetUnionModel
21
+ from diffusers.pipelines.auto_pipeline import (
22
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
23
+ AUTO_INPAINT_PIPELINES_MAPPING,
24
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
25
+ AutoPipelineForImage2Image,
26
+ AutoPipelineForInpainting,
27
+ AutoPipelineForText2Image,
28
+ _get_task_class,
29
+ )
30
+ from huggingface_hub.utils import validate_hf_hub_args
31
+
32
+ from optimum.rbln.configuration_utils import RBLNModelConfig
33
+ from optimum.rbln.modeling_base import RBLNBaseModel
34
+ from optimum.rbln.utils.model_utils import (
35
+ MODEL_MAPPING,
36
+ convert_hf_to_rbln_model_name,
37
+ convert_rbln_to_hf_model_name,
38
+ get_rbln_model_cls,
39
+ )
40
+
41
+
42
+ class RBLNAutoPipelineBase:
43
+ _model_mapping = None
44
+ _model_mapping_names = None
45
+
46
+ @classmethod
47
+ def get_rbln_cls(cls, pretrained_model_name_or_path: Union[str, Path], export: bool = None, **kwargs):
48
+ if isinstance(pretrained_model_name_or_path, Path):
49
+ pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()
50
+
51
+ if export is None:
52
+ export = not cls._is_compiled_pipeline(pretrained_model_name_or_path, **kwargs)
53
+
54
+ if export:
55
+ hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
56
+ rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
57
+ else:
58
+ rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
59
+ if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
60
+ raise ValueError(
61
+ f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
62
+ "Please use the `from_pretrained()` method of the appropriate class to load this model, "
63
+ f"or directly use '{rbln_class_name}.from_pretrained()`."
64
+ )
65
+
66
+ try:
67
+ rbln_cls = get_rbln_model_cls(rbln_class_name)
68
+ except AttributeError as e:
69
+ raise AttributeError(
70
+ f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
71
+ "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
72
+ ) from e
73
+
74
+ return rbln_cls
75
+
76
+ @classmethod
77
+ def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
78
+ """
79
+ Retrieve the path to the compiled model directory for a given RBLN model.
80
+
81
+ Args:
82
+ pretrained_model_name_or_path (str): Identifier of the model.
83
+
84
+ Returns:
85
+ str: Path to the compiled model directory.
86
+ """
87
+ model_index_config = cls.load_config(pretrained_model_name_or_path)
88
+
89
+ if "_class_name" not in model_index_config:
90
+ raise ValueError(
91
+ "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
92
+ "Please use the `from_pretrained()` method of the appropriate class to load this model."
93
+ )
94
+
95
+ return model_index_config["_class_name"]
96
+
97
+ @classmethod
98
+ def _is_compiled_pipeline(
99
+ cls,
100
+ pretrained_model_name_or_path: Union[str, Path],
101
+ cache_dir=None,
102
+ force_download=False,
103
+ proxies=None,
104
+ token=None,
105
+ local_files_only=False,
106
+ revision=None,
107
+ **kwargs,
108
+ ):
109
+ config: dict = cls.load_config(
110
+ pretrained_model_name_or_path,
111
+ cache_dir=cache_dir,
112
+ force_download=force_download,
113
+ proxies=proxies,
114
+ token=token,
115
+ local_files_only=local_files_only,
116
+ revision=revision,
117
+ )
118
+ for value in config.values():
119
+ if isinstance(value, list) and len(value) > 0 and value[0] == "optimum.rbln":
120
+ return True
121
+ return False
122
+
123
+ @classmethod
124
+ def infer_hf_model_class(
125
+ cls,
126
+ pretrained_model_or_path: Union[str, Path],
127
+ cache_dir=None,
128
+ force_download=False,
129
+ proxies=None,
130
+ token=None,
131
+ local_files_only=False,
132
+ revision=None,
133
+ **kwargs,
134
+ ):
135
+ config = cls.load_config(
136
+ pretrained_model_or_path,
137
+ cache_dir=cache_dir,
138
+ force_download=force_download,
139
+ proxies=proxies,
140
+ token=token,
141
+ local_files_only=local_files_only,
142
+ revision=revision,
143
+ )
144
+ pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)
145
+
146
+ pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)
147
+
148
+ return pipeline_cls
149
+
150
+ @classmethod
151
+ def get_pipeline_key_name(cls, config, **kwargs):
152
+ orig_class_name = config["_class_name"]
153
+ if "ControlPipeline" in orig_class_name:
154
+ to_replace = "ControlPipeline"
155
+ else:
156
+ to_replace = "Pipeline"
157
+
158
+ if "controlnet" in kwargs:
159
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
160
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
161
+ else:
162
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
163
+ if "enable_pag" in kwargs:
164
+ enable_pag = kwargs.pop("enable_pag")
165
+ if enable_pag:
166
+ orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
167
+
168
+ return orig_class_name
169
+
170
+ @classmethod
171
+ @validate_hf_hub_args
172
+ def from_pretrained(
173
+ cls,
174
+ model_id: Union[str, Path],
175
+ *,
176
+ export: bool = None,
177
+ rbln_config: Union[Dict[str, Any], RBLNModelConfig] = {},
178
+ **kwargs: Any,
179
+ ):
180
+ """
181
+ Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.
182
+
183
+ This method determines the concrete `RBLN*` model class that corresponds to the
184
+ underlying Diffusers pipeline architecture and dispatches to that class's
185
+ `from_pretrained()` implementation. If a compiled RBLN folder is detected at `model_id`
186
+ (or `export=False` is explicitly passed), it loads the compiled artifacts; otherwise it
187
+ compiles from the original Diffusers checkpoint.
188
+
189
+ Args:
190
+ model_id:
191
+ HF repo id or local path. For compiled models, this should point to a directory
192
+ (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
193
+ export:
194
+ Force compilation from a Diffusers checkpoint. When `None`, this is inferred by
195
+ checking whether compiled artifacts exist at `model_id`.
196
+ rbln_config:
197
+ RBLN compilation/runtime configuration. May be provided as a dictionary or as an
198
+ instance of the specific model's config class (e.g., `RBLNFluxPipelineConfig`).
199
+ kwargs: Additional keyword arguments.
200
+ - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
201
+ - Remaining arguments are forwarded to the Diffusers loader.
202
+
203
+ Returns:
204
+ RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
205
+ inference on RBLN NPUs.
206
+
207
+ """
208
+ rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
209
+ return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
210
+
211
@staticmethod
def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
    """
    Add an RBLN model class to `MODEL_MAPPING` under its class name.

    Args:
        rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
        exist_ok (bool): When false, registration fails if the class name is already a key
            of `MODEL_MAPPING` or is already an attribute of the `optimum.rbln` package.
            When true, the entry is written (or overwritten) regardless.

    Raises:
        ValueError: If `rbln_cls` is not a subclass of `RBLNBaseModel`, or if the name is
            already taken and `exist_ok` is false.
    """
    if not issubclass(rbln_cls, RBLNBaseModel):
        raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")

    cls_name = rbln_cls.__name__

    # A name counts as taken if it was registered before or if the package itself exports it.
    packaged_cls = getattr(importlib.import_module("optimum.rbln"), cls_name, None)
    name_taken = cls_name in MODEL_MAPPING or packaged_cls is not None
    if name_taken and not exist_ok:
        raise ValueError(f"Model for {cls_name} already registered.")

    MODEL_MAPPING[cls_name] = rbln_cls
229
+
230
+
231
class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
    """Text2Image AutoPipeline for RBLN NPUs."""

    # Task-specific pipeline mapping used by this auto class.
    _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
    # Same keys as the mapping above, with each pipeline class reduced to its class name.
    _model_mapping_names = {
        key: pipeline_cls.__name__ for key, pipeline_cls in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()
    }
236
+
237
+
238
class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
    """Image2Image AutoPipeline for RBLN NPUs."""

    # Task-specific pipeline mapping used by this auto class.
    _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
    # Same keys as the mapping above, with each pipeline class reduced to its class name.
    _model_mapping_names = {
        key: pipeline_cls.__name__ for key, pipeline_cls in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Derive the image-to-image pipeline class name from a checkpoint's `_class_name`.

        The `_class_name` stored in `config` can be:
          - `*Pipeline` (for regular text-to-image checkpoint)
          - `*ControlPipeline` (for Flux tools specific checkpoint)
          - `*Img2ImgPipeline` (for refiner checkpoint)

        Args:
            config: Mapping that provides the `"_class_name"` entry.
            **kwargs: Only `controlnet` and `enable_pag` are inspected here.

        Returns:
            str: The rewritten pipeline class name.
        """
        name = config["_class_name"]

        # Pick the trailing part of the name that every later rewrite is anchored on.
        if "Img2Img" in name:
            suffix = "Img2ImgPipeline"
        elif "ControlPipeline" in name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            is_union = isinstance(kwargs["controlnet"], ControlNetUnionModel)
            controlnet_tag = "ControlNetUnion" if is_union else "ControlNet"
            name = name.replace(suffix, controlnet_tag + suffix)

        # `kwargs` is this call's own dict, so popping does not affect the caller.
        if kwargs.pop("enable_pag", False):
            name = name.replace(suffix, "PAG" + suffix)

        # Flux-tools "Control" checkpoints map onto their image-to-image variant.
        if suffix == "ControlPipeline":
            name = name.replace(suffix, "ControlImg2ImgPipeline")

        return name
272
+
273
+
274
class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
    """Inpainting AutoPipeline for RBLN NPUs."""

    # Task-specific pipeline mapping used by this auto class.
    _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
    # Same keys as the mapping above, with each pipeline class reduced to its class name.
    _model_mapping_names = {
        key: pipeline_cls.__name__ for key, pipeline_cls in AUTO_INPAINT_PIPELINES_MAPPING.items()
    }

    @classmethod
    def get_pipeline_key_name(cls, config, **kwargs):
        """
        Derive the inpainting pipeline class name from a checkpoint's `_class_name`.

        The `_class_name` stored in `config` can be:
          - `*InpaintPipeline` (for inpaint-specific checkpoint)
          - `*ControlPipeline` (for Flux tools specific checkpoint)
          - `*Pipeline` (for regular text-to-image checkpoint)

        Args:
            config: Mapping that provides the `"_class_name"` entry.
            **kwargs: Only `controlnet` and `enable_pag` are inspected here.

        Returns:
            str: The rewritten pipeline class name.
        """
        name = config["_class_name"]

        # Pick the trailing part of the name that every later rewrite is anchored on.
        if "Inpaint" in name:
            suffix = "InpaintPipeline"
        elif "ControlPipeline" in name:
            suffix = "ControlPipeline"
        else:
            suffix = "Pipeline"

        if "controlnet" in kwargs:
            is_union = isinstance(kwargs["controlnet"], ControlNetUnionModel)
            controlnet_tag = "ControlNetUnion" if is_union else "ControlNet"
            name = name.replace(suffix, controlnet_tag + suffix)

        # `kwargs` is this call's own dict, so popping does not affect the caller.
        if kwargs.pop("enable_pag", False):
            name = name.replace(suffix, "PAG" + suffix)

        # Flux-tools "Control" checkpoints map onto their inpainting variant.
        if suffix == "ControlPipeline":
            name = name.replace(suffix, "ControlInpaintPipeline")

        return name
@@ -96,6 +96,26 @@ class RBLNMultiControlNetModel(RBLNModel):
96
96
  guess_mode: bool = False,
97
97
  return_dict: bool = True,
98
98
  ):
99
+ """
100
+ Forward pass for the RBLN-optimized MultiControlNetModel.
101
+
102
+ This method processes multiple ControlNet models in sequence, applying each one to the input sample
103
+ with its corresponding conditioning image and scale factor. The outputs from all ControlNets are
104
+ merged by addition to produce the final control signals.
105
+
106
+ Args:
107
+ sample (torch.FloatTensor): The noisy input tensor.
108
+ timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
109
+ encoder_hidden_states (torch.Tensor): The encoder hidden states from the text encoder.
110
+ controlnet_cond (List[torch.Tensor]): A list of conditional input tensors, one for each ControlNet model.
111
+ conditioning_scale (List[float]): A list of scale factors for each ControlNet output. Each scale
112
+ controls the strength of the corresponding ControlNet's influence on the generation.
113
+ return_dict (bool): Whether or not to return a dictionary instead of a plain tuple. Currently,
114
+ this method always returns a tuple regardless of this parameter.
115
+
116
+ Returns:
117
+ (Tuple[List[torch.Tensor], torch.Tensor])
118
+ """
99
119
  for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
100
120
  down_samples, mid_sample = controlnet(
101
121
  sample=sample.contiguous(),
@@ -12,10 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
- from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig
18
+ from ....transformers import RBLNSiglipVisionModelConfig
19
19
 
20
20
 
21
21
  class RBLNVideoSafetyModelConfig(RBLNModelConfig):
@@ -56,11 +56,11 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
56
56
  Configuration class for RBLN Cosmos Safety Checker.
57
57
  """
58
58
 
59
- submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]
59
+ submodules = ["llamaguard3", "video_safety_model", "face_blur_filter", "siglip_encoder"]
60
60
 
61
61
  def __init__(
62
62
  self,
63
- aegis: Optional[RBLNModelConfig] = None,
63
+ llamaguard3: Optional[RBLNModelConfig] = None,
64
64
  video_safety_model: Optional[RBLNModelConfig] = None,
65
65
  face_blur_filter: Optional[RBLNModelConfig] = None,
66
66
  siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
@@ -69,37 +69,40 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
69
69
  image_size: Optional[Tuple[int, int]] = None,
70
70
  height: Optional[int] = None,
71
71
  width: Optional[int] = None,
72
- **kwargs: Dict[str, Any],
72
+ max_seq_len: Optional[int] = None,
73
+ **kwargs: Any,
73
74
  ):
74
75
  super().__init__(**kwargs)
75
76
  if height is not None and width is not None:
76
77
  image_size = (height, width)
77
78
 
79
+ if max_seq_len is None:
80
+ max_seq_len = 512
81
+
78
82
  tensor_parallel_size = kwargs.get("tensor_parallel_size")
79
83
 
80
- self.aegis = self.init_submodule_config(
81
- RBLNLlamaForCausalLMConfig,
82
- aegis,
84
+ self.llamaguard3 = self.initialize_submodule_config(
85
+ llamaguard3,
86
+ cls_name="RBLNLlamaForCausalLMConfig",
83
87
  batch_size=batch_size,
84
88
  tensor_parallel_size=tensor_parallel_size,
89
+ max_seq_len=max_seq_len,
85
90
  )
86
-
87
- self.siglip_encoder = self.init_submodule_config(
88
- RBLNSiglipVisionModelConfig,
91
+ self.siglip_encoder = self.initialize_submodule_config(
89
92
  siglip_encoder,
93
+ cls_name="RBLNSiglipVisionModelConfig",
90
94
  batch_size=batch_size,
91
95
  image_size=(384, 384),
92
96
  )
93
-
94
- self.video_safety_model = self.init_submodule_config(
95
- RBLNVideoSafetyModelConfig,
97
+ self.video_safety_model = self.initialize_submodule_config(
96
98
  video_safety_model,
99
+ cls_name="RBLNVideoSafetyModelConfig",
97
100
  batch_size=batch_size,
98
101
  input_size=1152,
99
102
  )
100
- self.face_blur_filter = self.init_submodule_config(
101
- RBLNRetinaFaceFilterConfig,
103
+ self.face_blur_filter = self.initialize_submodule_config(
102
104
  face_blur_filter,
105
+ cls_name="RBLNRetinaFaceFilterConfig",
103
106
  batch_size=batch_size,
104
107
  image_size=image_size,
105
108
  )