optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167) hide show
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Dict, List, Tuple, Union
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
16
16
 
17
17
  import rebel
18
18
  import torch
@@ -214,13 +214,41 @@ class RBLNAutoencoderKL(RBLNModel):
214
214
  for compiled_model, device_val in zip(compiled_models, device_vals)
215
215
  ]
216
216
 
217
- def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
217
+ def encode(
218
+ self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
219
+ ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
220
+ """
221
+ Encode an input image into a latent representation.
222
+
223
+ Args:
224
+ x: The input image to encode.
225
+ return_dict:
226
+ Whether to return output as a dictionary. Defaults to True.
227
+ kwargs: Additional arguments to pass to the encoder.
228
+
229
+ Returns:
230
+ The latent representation or AutoencoderKLOutput if return_dict=True
231
+ """
218
232
  posterior = self.encoder.encode(x)
219
233
  if not return_dict:
220
234
  return (posterior,)
221
235
  return AutoencoderKLOutput(latent_dist=posterior)
222
236
 
223
- def decode(self, z: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
237
+ def decode(
238
+ self, z: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
239
+ ) -> Union[torch.FloatTensor, DecoderOutput]:
240
+ """
241
+ Decode a latent representation into an image.
242
+
243
+ Args:
244
+ z: The latent representation to decode.
245
+ return_dict:
246
+ Whether to return output as a dictionary. Defaults to True.
247
+ kwargs: Additional arguments to pass to the decoder.
248
+
249
+ Returns:
250
+ The decoded image or DecoderOutput if return_dict=True
251
+ """
224
252
  dec = self.decoder.decode(z)
225
253
  if not return_dict:
226
254
  return (dec,)
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Dict, List, Union
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Union
16
16
 
17
17
  import rebel
18
18
  import torch
@@ -205,13 +205,38 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
205
205
  for compiled_model, device_val in zip(compiled_models, device_vals)
206
206
  ]
207
207
 
208
- def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
208
+ def encode(
209
+ self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
210
+ ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
211
+ """
212
+ Encode an input video into a latent representation.
213
+
214
+ Args:
215
+ x: The input video to encode.
216
+ return_dict:
217
+ Whether to return output as a dictionary. Defaults to True.
218
+ kwargs: Additional arguments to pass to the encoder.
219
+
220
+ Returns:
221
+ The latent representation or AutoencoderKLOutput if return_dict=True
222
+ """
209
223
  posterior = self.encoder.encode(x)
210
224
  if not return_dict:
211
225
  return (posterior,)
212
226
  return AutoencoderKLOutput(latent_dist=posterior)
213
227
 
214
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
228
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[torch.FloatTensor, DecoderOutput]:
229
+ """
230
+ Decode a latent representation into a video.
231
+
232
+ Args:
233
+ z: The latent representation to decode.
234
+ return_dict:
235
+ Whether to return output as a dictionary. Defaults to True.
236
+
237
+ Returns:
238
+ The decoded video or DecoderOutput if return_dict=True
239
+ """
215
240
  decoded = self.decoder.decode(z)
216
241
 
217
242
  if not return_dict:
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, List, Union
15
+ from typing import TYPE_CHECKING, Any, List, Union
16
16
 
17
17
  import rebel
18
18
  import torch
@@ -170,13 +170,41 @@ class RBLNVQModel(RBLNModel):
170
170
  for compiled_model, device_val in zip(compiled_models, device_vals)
171
171
  ]
172
172
 
173
- def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
173
+ def encode(
174
+ self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
175
+ ) -> Union[torch.FloatTensor, VQEncoderOutput]:
176
+ """
177
+ Encode an input image into a quantized latent representation.
178
+
179
+ Args:
180
+ x: The input image to encode.
181
+ return_dict:
182
+ Whether to return output as a dictionary. Defaults to True.
183
+ kwargs: Additional arguments to pass to the encoder/quantizer.
184
+
185
+ Returns:
186
+ The quantized latent representation or a specific output object.
187
+ """
174
188
  posterior = self.encoder.encode(x)
175
189
  if not return_dict:
176
190
  return (posterior,)
177
191
  return VQEncoderOutput(latents=posterior)
178
192
 
179
- def decode(self, h: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
193
+ def decode(
194
+ self, h: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
195
+ ) -> Union[torch.FloatTensor, DecoderOutput]:
196
+ """
197
+ Decode a quantized latent representation back into an image.
198
+
199
+ Args:
200
+ h: The quantized latent representation to decode.
201
+ return_dict:
202
+ Whether to return output as a dictionary. Defaults to True.
203
+ kwargs: Additional arguments to pass to the decoder.
204
+
205
+ Returns:
206
+ The decoded image or a DecoderOutput object.
207
+ """
180
208
  dec, commit_loss = self.decoder.decode(h, **kwargs)
181
209
  if not return_dict:
182
210
  return (dec, commit_loss)
@@ -59,7 +59,7 @@ class RBLNPriorTransformer(RBLNModel):
59
59
  """
60
60
  RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.
61
61
 
62
- The Prior Transformer takes text and/or image embeddings from encoders (like CLIP) and
62
+ The PriorTransformer takes text and/or image embeddings from encoders (like CLIP) and
63
63
  maps them to a shared latent space that guides the diffusion process to generate the desired image.
64
64
 
65
65
  This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
@@ -94,7 +94,15 @@ class CosmosTransformer3DModelWrapper(torch.nn.Module):
94
94
 
95
95
 
96
96
  class RBLNCosmosTransformer3DModel(RBLNModel):
97
- """RBLN wrapper for the Cosmos Transformer model."""
97
+ """
98
+ RBLN implementation of CosmosTransformer3DModel for diffusion models like Cosmos.
99
+
100
+ The CosmosTransformer3DModel takes text and/or image embeddings from encoders (like CLIP) and
101
+ maps them to a shared latent space that guides the diffusion process to generate the desired image.
102
+
103
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
104
+ the library implements for all its models.
105
+ """
98
106
 
99
107
  hf_library_name = "diffusers"
100
108
  auto_model_class = CosmosTransformer3DModel
@@ -59,7 +59,15 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
59
59
 
60
60
 
61
61
  class RBLNSD3Transformer2DModel(RBLNModel):
62
- """RBLN wrapper for the Stable Diffusion 3 MMDiT Transformer model."""
62
+ """
63
+ RBLN implementation of SD3Transformer2DModel for diffusion models like Stable Diffusion 3.
64
+
65
+ The SD3Transformer2DModel takes text and/or image embeddings from encoders (like CLIP) and
66
+ maps them to a shared latent space that guides the diffusion process to generate the desired image.
67
+
68
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
69
+ the library implements for all its models.
70
+ """
63
71
 
64
72
  hf_library_name = "diffusers"
65
73
  auto_model_class = SD3Transformer2DModel
@@ -141,10 +141,13 @@ class _UNet_Kandinsky(torch.nn.Module):
141
141
 
142
142
  class RBLNUNet2DConditionModel(RBLNModel):
143
143
  """
144
- Configuration class for RBLN UNet2DCondition models.
144
+ RBLN implementation of UNet2DConditionModel for diffusion models.
145
145
 
146
- This class inherits from RBLNModelConfig and provides specific configuration options
147
- for UNet2DCondition models used in diffusion-based image generation.
146
+ This model is used to accelerate UNet2DCondition models from diffusers library on RBLN NPUs.
147
+ It is a key component in diffusion-based image generation models like Stable Diffusion.
148
+
149
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
150
+ the library implements for all its models.
148
151
  """
149
152
 
150
153
  hf_library_name = "diffusers"
@@ -18,6 +18,11 @@ from transformers.utils import _LazyModule
18
18
 
19
19
 
20
20
  _import_structure = {
21
+ "auto_pipeline": [
22
+ "RBLNAutoPipelineForImage2Image",
23
+ "RBLNAutoPipelineForInpainting",
24
+ "RBLNAutoPipelineForText2Image",
25
+ ],
21
26
  "controlnet": [
22
27
  "RBLNMultiControlNetModel",
23
28
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -56,6 +61,11 @@ _import_structure = {
56
61
  ],
57
62
  }
58
63
  if TYPE_CHECKING:
64
+ from .auto_pipeline import (
65
+ RBLNAutoPipelineForImage2Image,
66
+ RBLNAutoPipelineForInpainting,
67
+ RBLNAutoPipelineForText2Image,
68
+ )
59
69
  from .controlnet import (
60
70
  RBLNMultiControlNetModel,
61
71
  RBLNStableDiffusionControlNetImg2ImgPipeline,
@@ -63,11 +73,7 @@ if TYPE_CHECKING:
63
73
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
64
74
  RBLNStableDiffusionXLControlNetPipeline,
65
75
  )
66
- from .cosmos import (
67
- RBLNCosmosSafetyChecker,
68
- RBLNCosmosTextToWorldPipeline,
69
- RBLNCosmosVideoToWorldPipeline,
70
- )
76
+ from .cosmos import RBLNCosmosSafetyChecker, RBLNCosmosTextToWorldPipeline, RBLNCosmosVideoToWorldPipeline
71
77
  from .kandinsky2_2 import (
72
78
  RBLNKandinskyV22CombinedPipeline,
73
79
  RBLNKandinskyV22Img2ImgCombinedPipeline,
@@ -0,0 +1,307 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import importlib
17
+ from pathlib import Path
18
+ from typing import Any, Dict, Type, Union
19
+
20
+ from diffusers.models.controlnets import ControlNetUnionModel
21
+ from diffusers.pipelines.auto_pipeline import (
22
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
23
+ AUTO_INPAINT_PIPELINES_MAPPING,
24
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
25
+ AutoPipelineForImage2Image,
26
+ AutoPipelineForInpainting,
27
+ AutoPipelineForText2Image,
28
+ _get_task_class,
29
+ )
30
+ from huggingface_hub.utils import validate_hf_hub_args
31
+
32
+ from optimum.rbln.configuration_utils import RBLNModelConfig
33
+ from optimum.rbln.modeling_base import RBLNBaseModel
34
+ from optimum.rbln.utils.model_utils import (
35
+ MODEL_MAPPING,
36
+ convert_hf_to_rbln_model_name,
37
+ convert_rbln_to_hf_model_name,
38
+ get_rbln_model_cls,
39
+ )
40
+
41
+
42
+ class RBLNAutoPipelineBase:
43
+ _model_mapping = None
44
+ _model_mapping_names = None
45
+
46
+ @classmethod
47
+ def get_rbln_cls(cls, pretrained_model_name_or_path: Union[str, Path], export: bool = None, **kwargs):
48
+ if isinstance(pretrained_model_name_or_path, Path):
49
+ pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()
50
+
51
+ if export is None:
52
+ export = not cls._is_compiled_pipeline(pretrained_model_name_or_path, **kwargs)
53
+
54
+ if export:
55
+ hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
56
+ rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
57
+ else:
58
+ rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
59
+ if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
60
+ raise ValueError(
61
+ f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
62
+ "Please use the `from_pretrained()` method of the appropriate class to load this model, "
63
+ f"or directly use '{rbln_class_name}.from_pretrained()`."
64
+ )
65
+
66
+ try:
67
+ rbln_cls = get_rbln_model_cls(rbln_class_name)
68
+ except AttributeError as e:
69
+ raise AttributeError(
70
+ f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
71
+ "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
72
+ ) from e
73
+
74
+ return rbln_cls
75
+
76
+ @classmethod
77
+ def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
78
+ """
79
+ Retrieve the path to the compiled model directory for a given RBLN model.
80
+
81
+ Args:
82
+ pretrained_model_name_or_path (str): Identifier of the model.
83
+
84
+ Returns:
85
+ str: Path to the compiled model directory.
86
+ """
87
+ model_index_config = cls.load_config(pretrained_model_name_or_path)
88
+
89
+ if "_class_name" not in model_index_config:
90
+ raise ValueError(
91
+ "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
92
+ "Please use the `from_pretrained()` method of the appropriate class to load this model."
93
+ )
94
+
95
+ return model_index_config["_class_name"]
96
+
97
+ @classmethod
98
+ def _is_compiled_pipeline(
99
+ cls,
100
+ pretrained_model_name_or_path: Union[str, Path],
101
+ cache_dir=None,
102
+ force_download=False,
103
+ proxies=None,
104
+ token=None,
105
+ local_files_only=False,
106
+ revision=None,
107
+ **kwargs,
108
+ ):
109
+ config: dict = cls.load_config(
110
+ pretrained_model_name_or_path,
111
+ cache_dir=cache_dir,
112
+ force_download=force_download,
113
+ proxies=proxies,
114
+ token=token,
115
+ local_files_only=local_files_only,
116
+ revision=revision,
117
+ )
118
+ for value in config.values():
119
+ if isinstance(value, list) and len(value) > 0 and value[0] == "optimum.rbln":
120
+ return True
121
+ return False
122
+
123
+ @classmethod
124
+ def infer_hf_model_class(
125
+ cls,
126
+ pretrained_model_or_path: Union[str, Path],
127
+ cache_dir=None,
128
+ force_download=False,
129
+ proxies=None,
130
+ token=None,
131
+ local_files_only=False,
132
+ revision=None,
133
+ **kwargs,
134
+ ):
135
+ config = cls.load_config(
136
+ pretrained_model_or_path,
137
+ cache_dir=cache_dir,
138
+ force_download=force_download,
139
+ proxies=proxies,
140
+ token=token,
141
+ local_files_only=local_files_only,
142
+ revision=revision,
143
+ )
144
+ pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)
145
+
146
+ pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)
147
+
148
+ return pipeline_cls
149
+
150
+ @classmethod
151
+ def get_pipeline_key_name(cls, config, **kwargs):
152
+ orig_class_name = config["_class_name"]
153
+ if "ControlPipeline" in orig_class_name:
154
+ to_replace = "ControlPipeline"
155
+ else:
156
+ to_replace = "Pipeline"
157
+
158
+ if "controlnet" in kwargs:
159
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
160
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
161
+ else:
162
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
163
+ if "enable_pag" in kwargs:
164
+ enable_pag = kwargs.pop("enable_pag")
165
+ if enable_pag:
166
+ orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
167
+
168
+ return orig_class_name
169
+
170
+ @classmethod
171
+ @validate_hf_hub_args
172
+ def from_pretrained(
173
+ cls,
174
+ model_id: Union[str, Path],
175
+ *,
176
+ export: bool = None,
177
+ rbln_config: Union[Dict[str, Any], RBLNModelConfig] = {},
178
+ **kwargs: Any,
179
+ ):
180
+ """
181
+ Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.
182
+
183
+ This method determines the concrete `RBLN*` model class that corresponds to the
184
+ underlying Diffusers pipeline architecture and dispatches to that class's
185
+ `from_pretrained()` implementation. If a compiled RBLN folder is detected at `model_id`
186
+ (or `export=False` is explicitly passed), it loads the compiled artifacts; otherwise it
187
+ compiles from the original Diffusers checkpoint.
188
+
189
+ Args:
190
+ model_id:
191
+ HF repo id or local path. For compiled models, this should point to a directory
192
+ (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
193
+ export:
194
+ Force compilation from a Diffusers checkpoint. When `None`, this is inferred by
195
+ checking whether compiled artifacts exist at `model_id`.
196
+ rbln_config:
197
+ RBLN compilation/runtime configuration. May be provided as a dictionary or as an
198
+ instance of the specific model's config class (e.g., `RBLNFluxPipelineConfig`).
199
+ kwargs: Additional keyword arguments.
200
+ - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
201
+ - Remaining arguments are forwarded to the Diffusers loader.
202
+
203
+ Returns:
204
+ RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
205
+ inference on RBLN NPUs.
206
+
207
+ """
208
+ rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
209
+ return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
210
+
211
+ @staticmethod
212
+ def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
213
+ """
214
+ Register a new RBLN model class.
215
+
216
+ Args:
217
+ rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
218
+ exist_ok (bool): Whether to allow registering an already registered model.
219
+ """
220
+ if not issubclass(rbln_cls, RBLNBaseModel):
221
+ raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")
222
+
223
+ native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
224
+ if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
225
+ if not exist_ok:
226
+ raise ValueError(f"Model for {rbln_cls.__name__} already registered.")
227
+
228
+ MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
229
+
230
+
231
+ class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
232
+ """Text2Image AutoPipeline for RBLN NPUs."""
233
+
234
+ _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
235
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()}
236
+
237
+
238
+ class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
239
+ """Image2Image AutoPipeline for RBLN NPUs."""
240
+
241
+ _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
242
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()}
243
+
244
+ @classmethod
245
+ def get_pipeline_key_name(cls, config, **kwargs):
246
+ orig_class_name = config["_class_name"]
247
+ # the `orig_class_name` can be:
248
+ # `- *Pipeline` (for regular text-to-image checkpoint)
249
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
250
+ # `- *Img2ImgPipeline` (for refiner checkpoint)
251
+ if "Img2Img" in orig_class_name:
252
+ to_replace = "Img2ImgPipeline"
253
+ elif "ControlPipeline" in orig_class_name:
254
+ to_replace = "ControlPipeline"
255
+ else:
256
+ to_replace = "Pipeline"
257
+
258
+ if "controlnet" in kwargs:
259
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
260
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
261
+ else:
262
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
263
+ if "enable_pag" in kwargs:
264
+ enable_pag = kwargs.pop("enable_pag")
265
+ if enable_pag:
266
+ orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
267
+
268
+ if to_replace == "ControlPipeline":
269
+ orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
270
+
271
+ return orig_class_name
272
+
273
+
274
+ class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
275
+ """Inpainting AutoPipeline for RBLN NPUs."""
276
+
277
+ _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
278
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_INPAINT_PIPELINES_MAPPING.items()}
279
+
280
+ @classmethod
281
+ def get_pipeline_key_name(cls, config, **kwargs):
282
+ orig_class_name = config["_class_name"]
283
+
284
+ # The `orig_class_name`` can be:
285
+ # `- *InpaintPipeline` (for inpaint-specific checkpoint)
286
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
287
+ # - or *Pipeline (for regular text-to-image checkpoint)
288
+ if "Inpaint" in orig_class_name:
289
+ to_replace = "InpaintPipeline"
290
+ elif "ControlPipeline" in orig_class_name:
291
+ to_replace = "ControlPipeline"
292
+ else:
293
+ to_replace = "Pipeline"
294
+
295
+ if "controlnet" in kwargs:
296
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
297
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
298
+ else:
299
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
300
+ if "enable_pag" in kwargs:
301
+ enable_pag = kwargs.pop("enable_pag")
302
+ if enable_pag:
303
+ orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
304
+ if to_replace == "ControlPipeline":
305
+ orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
306
+
307
+ return orig_class_name
@@ -12,10 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
- from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig
18
+ from ....transformers import RBLNSiglipVisionModelConfig
19
19
 
20
20
 
21
21
  class RBLNVideoSafetyModelConfig(RBLNModelConfig):
@@ -56,11 +56,11 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
56
56
  Configuration class for RBLN Cosmos Safety Checker.
57
57
  """
58
58
 
59
- submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]
59
+ submodules = ["llamaguard3", "video_safety_model", "face_blur_filter", "siglip_encoder"]
60
60
 
61
61
  def __init__(
62
62
  self,
63
- aegis: Optional[RBLNModelConfig] = None,
63
+ llamaguard3: Optional[RBLNModelConfig] = None,
64
64
  video_safety_model: Optional[RBLNModelConfig] = None,
65
65
  face_blur_filter: Optional[RBLNModelConfig] = None,
66
66
  siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
@@ -69,37 +69,40 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
69
69
  image_size: Optional[Tuple[int, int]] = None,
70
70
  height: Optional[int] = None,
71
71
  width: Optional[int] = None,
72
- **kwargs: Dict[str, Any],
72
+ max_seq_len: Optional[int] = None,
73
+ **kwargs: Any,
73
74
  ):
74
75
  super().__init__(**kwargs)
75
76
  if height is not None and width is not None:
76
77
  image_size = (height, width)
77
78
 
79
+ if max_seq_len is None:
80
+ max_seq_len = 512
81
+
78
82
  tensor_parallel_size = kwargs.get("tensor_parallel_size")
79
83
 
80
- self.aegis = self.init_submodule_config(
81
- RBLNLlamaForCausalLMConfig,
82
- aegis,
84
+ self.llamaguard3 = self.initialize_submodule_config(
85
+ llamaguard3,
86
+ cls_name="RBLNLlamaForCausalLMConfig",
83
87
  batch_size=batch_size,
84
88
  tensor_parallel_size=tensor_parallel_size,
89
+ max_seq_len=max_seq_len,
85
90
  )
86
-
87
- self.siglip_encoder = self.init_submodule_config(
88
- RBLNSiglipVisionModelConfig,
91
+ self.siglip_encoder = self.initialize_submodule_config(
89
92
  siglip_encoder,
93
+ cls_name="RBLNSiglipVisionModelConfig",
90
94
  batch_size=batch_size,
91
95
  image_size=(384, 384),
92
96
  )
93
-
94
- self.video_safety_model = self.init_submodule_config(
95
- RBLNVideoSafetyModelConfig,
97
+ self.video_safety_model = self.initialize_submodule_config(
96
98
  video_safety_model,
99
+ cls_name="RBLNVideoSafetyModelConfig",
97
100
  batch_size=batch_size,
98
101
  input_size=1152,
99
102
  )
100
- self.face_blur_filter = self.init_submodule_config(
101
- RBLNRetinaFaceFilterConfig,
103
+ self.face_blur_filter = self.initialize_submodule_config(
102
104
  face_blur_filter,
105
+ cls_name="RBLNRetinaFaceFilterConfig",
103
106
  batch_size=batch_size,
104
107
  image_size=image_size,
105
108
  )