optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -19,10 +19,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
19
19
 
20
20
  import torch
21
21
 
22
- from ..configuration_utils import ContextRblnConfig, RBLNModelConfig
22
+ from ..configuration_utils import ContextRblnConfig, RBLNModelConfig, get_rbln_config_class
23
23
  from ..modeling import RBLNModel
24
24
  from ..utils.decorator_utils import remove_compile_time_kwargs
25
25
  from ..utils.logging import get_logger
26
+ from ..utils.model_utils import get_rbln_model_cls
26
27
 
27
28
 
28
29
  logger = get_logger(__name__)
@@ -44,7 +45,7 @@ class RBLNDiffusionMixin:
44
45
  To use this mixin:
45
46
 
46
47
  1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
47
- 2. Define the required _submodules class variable listing the components to be compiled.
48
+ 2. Define the required _submodules and _optional_submodules class variable listing the components to be compiled.
48
49
 
49
50
  Example:
50
51
  ```python
@@ -54,6 +55,7 @@ class RBLNDiffusionMixin:
54
55
 
55
56
  Class Variables:
56
57
  _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
58
+ _optional_submodules: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])
57
59
 
58
60
  Methods:
59
61
  from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
@@ -66,6 +68,7 @@ class RBLNDiffusionMixin:
66
68
 
67
69
  _connected_classes = {}
68
70
  _submodules = []
71
+ _optional_submodules = []
69
72
  _prefix = {}
70
73
  _rbln_config_class = None
71
74
  _hf_class = None
@@ -110,18 +113,10 @@ class RBLNDiffusionMixin:
110
113
 
111
114
  @classmethod
112
115
  def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
113
- """
114
- Lazily loads and caches the corresponding RBLN model config class.
115
- """
116
+ # Lazily loads and caches the corresponding RBLN model config class.
116
117
  if cls._rbln_config_class is None:
117
118
  rbln_config_class_name = cls.__name__ + "Config"
118
- library = importlib.import_module("optimum.rbln")
119
- cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
120
- if cls._rbln_config_class is None:
121
- raise ValueError(
122
- f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
123
- "Please report it to the developers."
124
- )
119
+ cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
125
120
  return cls._rbln_config_class
126
121
 
127
122
  @classmethod
@@ -143,7 +138,7 @@ class RBLNDiffusionMixin:
143
138
  lora_ids: Optional[Union[str, List[str]]] = None,
144
139
  lora_weights_names: Optional[Union[str, List[str]]] = None,
145
140
  lora_scales: Optional[Union[float, List[float]]] = None,
146
- **kwargs,
141
+ **kwargs: Dict[str, Any],
147
142
  ) -> "RBLNDiffusionMixin":
148
143
  """
149
144
  Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
@@ -157,24 +152,25 @@ class RBLNDiffusionMixin:
157
152
  Args:
158
153
  model_id (`str`):
159
154
  The model ID or path to the pretrained model to load. Can be either:
155
+
160
156
  - A model ID from the HuggingFace Hub
161
157
  - A local path to a saved model directory
162
- export (`bool`, *optional*, defaults to `False`):
158
+ export:
163
159
  If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
164
160
  If False, loads an already compiled RBLN model from `model_id` without recompilation.
165
- model_save_dir (`os.PathLike`, *optional*):
161
+ model_save_dir:
166
162
  Directory to save the compiled model artifacts. Only used when `export=True`.
167
163
  If not provided and `export=True`, a temporary directory is used.
168
- rbln_config (`Dict[str, Any]`, *optional*, defaults to `{}`):
164
+ rbln_config:
169
165
  Configuration options for RBLN compilation. Can include settings for specific submodules
170
166
  such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
171
167
  pipeline being compiled.
172
- lora_ids (`str` or `List[str]`, *optional*):
168
+ lora_ids:
173
169
  LoRA adapter ID(s) to load and apply before compilation. LoRA weights are fused
174
170
  into the model weights during compilation. Only used when `export=True`.
175
- lora_weights_names (`str` or `List[str]`, *optional*):
171
+ lora_weights_names:
176
172
  Names of specific LoRA weight files to load, corresponding to lora_ids. Only used when `export=True`.
177
- lora_scales (`float` or `List[float]`, *optional*):
173
+ lora_scales:
178
174
  Scaling factor(s) to apply to the LoRA adapter(s). Only used when `export=True`.
179
175
  **kwargs:
180
176
  Additional arguments to pass to the underlying diffusion pipeline constructor or the
@@ -182,39 +178,50 @@ class RBLNDiffusionMixin:
182
178
  or the particular diffusion pipeline being used.
183
179
 
184
180
  Returns:
185
- `RBLNDiffusionMixin`: A compiled or loaded diffusion pipeline that can be used for inference on RBLN NPU.
186
- The returned object is an instance of the class that called this method, inheriting from RBLNDiffusionMixin.
181
+ A compiled or loaded diffusion pipeline that can be used for inference on RBLN NPU.
182
+ The returned object is an instance of the class that called this method, inheriting from RBLNDiffusionMixin.
187
183
  """
188
184
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
189
185
 
190
186
  if export:
191
187
  # keep submodules if user passed any of them.
192
188
  passed_submodules = {
193
- name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
189
+ name: kwargs.pop(name)
190
+ for name in cls._submodules + cls._optional_submodules
191
+ if isinstance(kwargs.get(name), RBLNModel)
194
192
  }
195
193
 
196
194
  else:
197
195
  # raise error if any of submodules are torch module.
198
196
  model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
199
- for submodule_name in cls._submodules:
200
- if isinstance(kwargs.get(submodule_name), torch.nn.Module):
201
- raise AssertionError(
202
- f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
197
+ for submodule_name in cls._submodules + cls._optional_submodules:
198
+ passed_submodule = kwargs.get(submodule_name, None)
199
+
200
+ if passed_submodule is None:
201
+ module_name, class_name = model_index_config[submodule_name]
202
+ if module_name != "optimum.rbln":
203
+ raise ValueError(
204
+ f"Invalid module_name '{module_name}' found in model_index.json for "
205
+ f"submodule '{submodule_name}'. "
206
+ "Expected 'optimum.rbln'. Please check the model_index.json configuration."
207
+ "If you want to compile, set `export=True`."
208
+ )
209
+
210
+ submodule_cls = get_rbln_model_cls(class_name)
211
+ submodule_config = getattr(rbln_config, submodule_name)
212
+ submodule = submodule_cls.from_pretrained(
213
+ model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
203
214
  )
204
215
 
205
- module_name, class_name = model_index_config[submodule_name]
206
- if module_name != "optimum.rbln":
207
- raise ValueError(
208
- f"Invalid module_name '{module_name}' found in model_index.json for "
209
- f"submodule '{submodule_name}'. "
210
- "Expected 'optimum.rbln'. Please check the model_index.json configuration."
211
- )
216
+ else:
217
+ if passed_submodule.__class__.__name__.startswith("RBLN"):
218
+ submodule = passed_submodule
219
+
220
+ elif isinstance(passed_submodule, torch.nn.Module):
221
+ raise AssertionError(
222
+ f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
223
+ )
212
224
 
213
- submodule_cls: Type[RBLNModel] = getattr(importlib.import_module("optimum.rbln"), class_name)
214
- submodule_config = getattr(rbln_config, submodule_name)
215
- submodule = submodule_cls.from_pretrained(
216
- model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
217
- )
218
225
  kwargs[submodule_name] = submodule
219
226
 
220
227
  with ContextRblnConfig(
@@ -293,7 +300,6 @@ class RBLNDiffusionMixin:
293
300
  elif isinstance(submodule, RBLNModel):
294
301
  pass
295
302
  elif submodule_name == "controlnet" and hasattr(submodule, "nets"):
296
- # In case of multicontrolnet
297
303
  submodule = cls._compile_multicontrolnet(
298
304
  controlnets=submodule,
299
305
  model_save_dir=model_save_dir,
@@ -301,11 +307,8 @@ class RBLNDiffusionMixin:
301
307
  prefix=prefix,
302
308
  )
303
309
  elif isinstance(submodule, torch.nn.Module):
304
- submodule_cls: RBLNModel = getattr(
305
- importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
306
- )
307
310
  subfolder = prefix + submodule_name
308
- submodule = submodule_cls.from_model(
311
+ submodule = submodule_rbln_cls.from_model(
309
312
  model=submodule,
310
313
  subfolder=subfolder,
311
314
  model_save_dir=model_save_dir,
@@ -362,10 +365,16 @@ class RBLNDiffusionMixin:
362
365
  # Causing warning messeages.
363
366
 
364
367
  update_dict = {}
365
- for submodule_name in cls._submodules:
368
+ for submodule_name in cls._submodules + cls._optional_submodules:
366
369
  # replace submodule
367
- setattr(model, submodule_name, submodules[submodule_name])
368
- update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
370
+ if submodule_name in submodules:
371
+ setattr(model, submodule_name, submodules[submodule_name])
372
+ update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
373
+ else:
374
+ # It assumes that the modules in _optional_components is compiled
375
+ # and already registered as an attribute of the model.
376
+ update_dict[submodule_name] = ("optimum.rbln", getattr(model, submodule_name).__class__.__name__)
377
+
369
378
  if cls._load_connected_pipes:
370
379
  for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
371
380
  prefix = cls._prefix.get(connected_pipe_name, "")
@@ -396,31 +405,29 @@ class RBLNDiffusionMixin:
396
405
  return model
397
406
 
398
407
  def get_compiled_image_size(self):
399
- if hasattr(self, "vae"):
408
+ if hasattr(self, "vae") and hasattr(self.vae, "image_size"):
400
409
  compiled_image_size = self.vae.image_size
401
410
  else:
402
411
  compiled_image_size = None
403
412
  return compiled_image_size
404
413
 
405
414
  def handle_additional_kwargs(self, **kwargs):
406
- """
407
- Function to handle additional compile-time parameters during inference.
408
-
409
- If the additional variable is determined by another module, this method should be overrided.
410
-
411
- Example:
412
- ```python
413
- if hasattr(self, "movq"):
414
- compiled_image_size = self.movq.image_size
415
- kwargs["height"] = compiled_image_size[0]
416
- kwargs["width"] = compiled_image_size[1]
417
-
418
- compiled_num_frames = self.unet.rbln_config.num_frames
419
- if compiled_num_frames is not None:
420
- kwargs["num_frames"] = compiled_num_frames
421
- return kwargs
422
- ```
423
- """
415
+ # Function to handle additional compile-time parameters during inference.
416
+
417
+ # If the additional variable is determined by another module, this method should be overrided.
418
+
419
+ # Example:
420
+ # ```python
421
+ # if hasattr(self, "movq"):
422
+ # compiled_image_size = self.movq.image_size
423
+ # kwargs["height"] = compiled_image_size[0]
424
+ # kwargs["width"] = compiled_image_size[1]
425
+
426
+ # compiled_num_frames = self.unet.rbln_config.num_frames
427
+ # if compiled_num_frames is not None:
428
+ # kwargs["num_frames"] = compiled_num_frames
429
+ # return kwargs
430
+ # ```
424
431
  return kwargs
425
432
 
426
433
  @remove_compile_time_kwargs
@@ -20,6 +20,7 @@ from transformers.utils import _LazyModule
20
20
  _import_structure = {
21
21
  "autoencoders": [
22
22
  "RBLNAutoencoderKL",
23
+ "RBLNAutoencoderKLCosmos",
23
24
  "RBLNVQModel",
24
25
  ],
25
26
  "unets": [
@@ -28,6 +29,7 @@ _import_structure = {
28
29
  "controlnet": ["RBLNControlNetModel"],
29
30
  "transformers": [
30
31
  "RBLNPriorTransformer",
32
+ "RBLNCosmosTransformer3DModel",
31
33
  "RBLNSD3Transformer2DModel",
32
34
  ],
33
35
  }
@@ -35,10 +37,12 @@ _import_structure = {
35
37
  if TYPE_CHECKING:
36
38
  from .autoencoders import (
37
39
  RBLNAutoencoderKL,
40
+ RBLNAutoencoderKLCosmos,
38
41
  RBLNVQModel,
39
42
  )
40
43
  from .controlnet import RBLNControlNetModel
41
44
  from .transformers import (
45
+ RBLNCosmosTransformer3DModel,
42
46
  RBLNPriorTransformer,
43
47
  RBLNSD3Transformer2DModel,
44
48
  )
@@ -13,4 +13,5 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from .autoencoder_kl import RBLNAutoencoderKL
16
+ from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
16
17
  from .vq_model import RBLNVQModel
@@ -38,6 +38,17 @@ logger = get_logger(__name__)
38
38
 
39
39
 
40
40
  class RBLNAutoencoderKL(RBLNModel):
41
+ """
42
+ RBLN implementation of AutoencoderKL (VAE) for diffusion models.
43
+
44
+ This model is used to accelerate AutoencoderKL (VAE) models from diffusers library on RBLN NPUs.
45
+ It can be configured to include both encoder and decoder, or just the decoder part for latent-to-image
46
+ conversion.
47
+
48
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
49
+ the library implements for all its models.
50
+ """
51
+
41
52
  auto_model_class = AutoencoderKL
42
53
  hf_library_name = "diffusers"
43
54
  _rbln_config_class = RBLNAutoencoderKLConfig
@@ -69,7 +80,12 @@ class RBLNAutoencoderKL(RBLNModel):
69
80
 
70
81
  wrapped_model.eval()
71
82
 
72
- compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
83
+ compiled_models[model_name] = cls.compile(
84
+ wrapped_model,
85
+ rbln_compile_config=rbln_config.compile_cfgs[i],
86
+ create_runtimes=rbln_config.create_runtimes,
87
+ device=rbln_config.device_map[model_name],
88
+ )
73
89
 
74
90
  return compiled_models
75
91
 
@@ -0,0 +1,219 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Dict, List, Union
16
+
17
+ import rebel
18
+ import torch
19
+ from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos, CosmosCausalConv3d
20
+ from diffusers.models.autoencoders.vae import DecoderOutput
21
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
22
+ from torch.nn import functional as F
23
+ from transformers import PretrainedConfig
24
+
25
+ from ....configuration_utils import RBLNCompileConfig
26
+ from ....modeling import RBLNModel
27
+ from ....utils.logging import get_logger
28
+ from ...configurations import RBLNAutoencoderKLCosmosConfig
29
+ from .vae import RBLNRuntimeCosmosVAEDecoder, RBLNRuntimeCosmosVAEEncoder, _VAECosmosDecoder, _VAECosmosEncoder
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ import torch
34
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
35
+
36
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
37
+
38
+ logger = get_logger(__name__)
39
+
40
+
41
class RBLNAutoencoderKLCosmos(RBLNModel):
    """
    RBLN implementation of AutoencoderKLCosmos for diffusion models.

    This model is used to accelerate AutoencoderKLCosmos models from diffusers library on RBLN NPUs.
    It can be configured to include both encoder and decoder, or just the decoder part for latent-to-video
    conversion.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    auto_model_class = AutoencoderKLCosmos
    hf_library_name = "diffusers"
    _rbln_config_class = RBLNAutoencoderKLCosmosConfig

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)

        # `self.model` holds runtimes in the order built by `_create_runtimes`:
        # the encoder first (only when `uses_encoder` is set), the decoder always last.
        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeCosmosVAEEncoder(
                runtime=self.model[0], main_input_name="x", use_slicing=self.rbln_config.use_slicing
            )

        self.decoder = RBLNRuntimeCosmosVAEDecoder(
            runtime=self.model[-1], main_input_name="z", use_slicing=self.rbln_config.use_slicing
        )
        self.image_size = self.rbln_config.image_size

    @classmethod
    def wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
    ) -> Union[torch.nn.Module, tuple]:
        """Wrap the VAE into traceable encoder/decoder modules (both in eval mode).

        Returns:
            `(encoder_model, decoder_model)` when `rbln_config.uses_encoder` is set,
            otherwise only `decoder_model`.
        """
        decoder_model = _VAECosmosDecoder(model)
        decoder_model.eval()

        if not rbln_config.uses_encoder:
            return decoder_model

        encoder_model = _VAECosmosEncoder(model)
        encoder_model.eval()
        return encoder_model, decoder_model

    @classmethod
    def get_compiled_model(
        cls, model, rbln_config: RBLNAutoencoderKLCosmosConfig
    ) -> Dict[str, rebel.RBLNCompiledModel]:
        """Compile the (optional) encoder and the decoder.

        `CosmosCausalConv3d.forward` is temporarily replaced at class level while
        compiling and is always restored afterwards. The patch is process-global, so
        concurrent use of CosmosCausalConv3d during compilation would see it.
        """

        def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Causal padding: replicate the first frame `temporal_pad` times along dim 2,
            # then apply spatial padding, then run the plain Conv3d forward.
            # The temporal step is skipped when temporal_pad == 0 — presumably to keep
            # a zero-count repeat/cat out of the compiled graph (TODO confirm).
            if self.temporal_pad != 0:
                hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
                hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
            hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
            return super(CosmosCausalConv3d, self).forward(hidden_states)

        # Capture the original before patching so `finally` can always restore it.
        original_forward = CosmosCausalConv3d.forward
        CosmosCausalConv3d.forward = replaced_forward
        try:
            compiled_models = {}
            if rbln_config.uses_encoder:
                encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
                compiled_models["encoder"] = cls.compile(
                    encoder_model,
                    rbln_compile_config=rbln_config.compile_cfgs[0],
                    create_runtimes=rbln_config.create_runtimes,
                    device=rbln_config.device_map["encoder"],
                )
            else:
                decoder_model = cls.wrap_model_if_needed(model, rbln_config)

            # The decoder is always compiled; its compile config is the last entry
            # (see `_update_rbln_config`).
            compiled_models["decoder"] = cls.compile(
                decoder_model,
                rbln_compile_config=rbln_config.compile_cfgs[-1],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map["decoder"],
            )
        finally:
            CosmosCausalConv3d.forward = original_forward

        return compiled_models

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Copy latent-shape information from the pipeline into `rbln_config.vae`.

        The latent channel count comes from `pipe.transformer.config.out_channels`; the
        temporal/spatial downsampling factors come from the pipeline's `vae_scale_factor_*`
        attributes. `submodule_name` is not used: the `vae` entry is always updated.
        """
        rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
        rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
        rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLCosmosConfig,
    ) -> RBLNAutoencoderKLCosmosConfig:
        """Build the static input shapes for the encoder (optional) and decoder.

        With slicing enabled the graphs are compiled for batch size 1, because the
        runtimes then feed one sample at a time.
        """
        batch_size = 1 if rbln_config.use_slicing else rbln_config.batch_size
        compile_cfgs = []
        if rbln_config.uses_encoder:
            vae_enc_input_info = [
                (
                    "x",
                    [
                        batch_size,
                        model_config.in_channels,
                        rbln_config.num_frames,
                        rbln_config.height,
                        rbln_config.width,
                    ],
                    "float32",
                ),
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        # Latent shape: the first frame is kept, and the remaining frames are
        # downsampled by the temporal factor; height/width use the spatial factor.
        num_latent_frames = (rbln_config.num_frames - 1) // rbln_config.vae_scale_factor_temporal + 1
        latent_height = rbln_config.height // rbln_config.vae_scale_factor_spatial
        latent_width = rbln_config.width // rbln_config.vae_scale_factor_spatial

        vae_dec_input_info = [
            (
                "z",
                [
                    batch_size,
                    rbln_config.num_channels_latents,
                    num_latent_frames,
                    latent_height,
                    latent_width,
                ],
                "float32",
            ),
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLCosmosConfig,
    ) -> List[rebel.Runtime]:
        """Create one runtime per compiled model, in `[encoder,] decoder` order."""
        if len(compiled_models) == 1:
            # Decoder-only build.
            expected_models = ["decoder"]
        else:
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs
    ) -> Union[AutoencoderKLOutput, tuple]:
        """Encode `x` with the compiled encoder.

        Requires a model built with `uses_encoder`; otherwise `self.encoder` is not set.

        Returns:
            `AutoencoderKLOutput(latent_dist=posterior)`, or `(posterior,)` when
            `return_dict` is False.
        """
        posterior = self.encoder.encode(x)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, tuple]:
        """Decode latents `z` with the compiled decoder.

        Returns:
            `DecoderOutput(sample=decoded)`, or `(decoded,)` when `return_dict` is False.
        """
        decoded = self.decoder.decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
@@ -15,17 +15,13 @@
15
15
  from typing import TYPE_CHECKING, List
16
16
 
17
17
  import torch
18
- from diffusers import AutoencoderKL, VQModel
19
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
18
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
20
19
 
21
- from ....utils.logging import get_logger
22
20
  from ....utils.runtime_utils import RBLNPytorchRuntime
23
21
 
24
22
 
25
23
  if TYPE_CHECKING:
26
- import torch
27
-
28
- logger = get_logger(__name__)
24
+ from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
29
25
 
30
26
 
31
27
  class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
@@ -40,6 +36,27 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
40
36
  return self.forward(z)
41
37
 
42
38
 
39
+ class RBLNRuntimeCosmosVAEEncoder(RBLNPytorchRuntime):
40
+ def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
41
+ if self.use_slicing and x.shape[0] > 1:
42
+ encoded_slices = [self.forward(x_slice) for x_slice in x.split(1)]
43
+ h = torch.cat(encoded_slices)
44
+ else:
45
+ h = self.forward(x)
46
+ posterior = IdentityDistribution(h)
47
+ return posterior
48
+
49
+
50
+ class RBLNRuntimeCosmosVAEDecoder(RBLNPytorchRuntime):
51
+ def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
52
+ if self.use_slicing and z.shape[0] > 1:
53
+ decoded_slices = [self.forward(z_slice) for z_slice in z.split(1)]
54
+ decoded = torch.cat(decoded_slices)
55
+ else:
56
+ decoded = self.forward(z)
57
+ return decoded
58
+
59
+
43
60
  class _VAEDecoder(torch.nn.Module):
44
61
  def __init__(self, vae: "AutoencoderKL"):
45
62
  super().__init__()
@@ -73,6 +90,26 @@ class _VAEEncoder(torch.nn.Module):
73
90
  return vae_out
74
91
 
75
92
 
93
+ class _VAECosmosEncoder(torch.nn.Module):
94
+ def __init__(self, vae: "AutoencoderKLCosmos"):
95
+ super().__init__()
96
+ self.vae = vae
97
+
98
+ def forward(self, x):
99
+ vae_out = self.vae._encode(x)
100
+ return vae_out
101
+
102
+
103
+ class _VAECosmosDecoder(torch.nn.Module):
104
+ def __init__(self, vae: "AutoencoderKLCosmos"):
105
+ super().__init__()
106
+ self.vae = vae
107
+
108
+ def forward(self, z):
109
+ vae_out = self.vae._decode(z, return_dict=False)
110
+ return vae_out
111
+
112
+
76
113
  class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
77
114
  def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
78
115
  h = self.forward(x.contiguous())
@@ -91,7 +128,7 @@ class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
91
128
 
92
129
 
93
130
  class _VQEncoder(torch.nn.Module):
94
- def __init__(self, vq_model: VQModel):
131
+ def __init__(self, vq_model: "VQModel"):
95
132
  super().__init__()
96
133
  self.vq_model = vq_model
97
134
 
@@ -106,7 +143,7 @@ class _VQEncoder(torch.nn.Module):
106
143
 
107
144
 
108
145
  class _VQDecoder(torch.nn.Module):
109
- def __init__(self, vq_model: VQModel):
146
+ def __init__(self, vq_model: "VQModel"):
110
147
  super().__init__()
111
148
  self.vq_model = vq_model
112
149
 
@@ -35,6 +35,17 @@ logger = get_logger(__name__)
35
35
 
36
36
 
37
37
  class RBLNVQModel(RBLNModel):
38
+ """
39
+ RBLN implementation of VQModel for diffusion models.
40
+
41
+ This model is used to accelerate VQModel models from diffusers library on RBLN NPUs.
42
+ It can be configured to include both encoder and decoder, or just the decoder part for latent-to-image
43
+ conversion.
44
+
45
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
46
+ the library implements for all its models.
47
+ """
48
+
38
49
  auto_model_class = VQModel
39
50
  config_name = "config.json"
40
51
  hf_library_name = "diffusers"
@@ -67,7 +78,12 @@ class RBLNVQModel(RBLNModel):
67
78
 
68
79
  wrapped_model.eval()
69
80
 
70
- compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
81
+ compiled_models[model_name] = cls.compile(
82
+ wrapped_model,
83
+ rbln_compile_config=rbln_config.compile_cfgs[i],
84
+ create_runtimes=rbln_config.create_runtimes,
85
+ device=rbln_config.device_map[model_name],
86
+ )
71
87
 
72
88
  return compiled_models
73
89