optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -26,6 +26,7 @@ _import_structure = {
26
26
  "RBLNModel",
27
27
  ],
28
28
  "configuration_utils": [
29
+ "RBLNAutoConfig",
29
30
  "RBLNCompileConfig",
30
31
  "RBLNModelConfig",
31
32
  ],
@@ -69,6 +70,8 @@ _import_structure = {
69
70
  "RBLNCLIPVisionModelConfig",
70
71
  "RBLNCLIPVisionModelWithProjection",
71
72
  "RBLNCLIPVisionModelWithProjectionConfig",
73
+ "RBLNColPaliForRetrieval",
74
+ "RBLNColPaliForRetrievalConfig",
72
75
  "RBLNDecoderOnlyModelForCausalLM",
73
76
  "RBLNDecoderOnlyModelForCausalLMConfig",
74
77
  "RBLNDistilBertForQuestionAnswering",
@@ -135,8 +138,17 @@ _import_structure = {
135
138
  "diffusers": [
136
139
  "RBLNAutoencoderKL",
137
140
  "RBLNAutoencoderKLConfig",
141
+ "RBLNAutoencoderKLCosmos",
142
+ "RBLNAutoencoderKLCosmosConfig",
138
143
  "RBLNControlNetModel",
139
144
  "RBLNControlNetModelConfig",
145
+ "RBLNCosmosTextToWorldPipeline",
146
+ "RBLNCosmosVideoToWorldPipeline",
147
+ "RBLNCosmosTextToWorldPipelineConfig",
148
+ "RBLNCosmosVideoToWorldPipelineConfig",
149
+ "RBLNCosmosSafetyChecker",
150
+ "RBLNCosmosTransformer3DModel",
151
+ "RBLNCosmosTransformer3DModelConfig",
140
152
  "RBLNDiffusionMixin",
141
153
  "RBLNKandinskyV22CombinedPipeline",
142
154
  "RBLNKandinskyV22CombinedPipelineConfig",
@@ -192,14 +204,24 @@ _import_structure = {
192
204
 
193
205
  if TYPE_CHECKING:
194
206
  from .configuration_utils import (
207
+ RBLNAutoConfig,
195
208
  RBLNCompileConfig,
196
209
  RBLNModelConfig,
197
210
  )
198
211
  from .diffusers import (
199
212
  RBLNAutoencoderKL,
200
213
  RBLNAutoencoderKLConfig,
214
+ RBLNAutoencoderKLCosmos,
215
+ RBLNAutoencoderKLCosmosConfig,
201
216
  RBLNControlNetModel,
202
217
  RBLNControlNetModelConfig,
218
+ RBLNCosmosSafetyChecker,
219
+ RBLNCosmosTextToWorldPipeline,
220
+ RBLNCosmosTextToWorldPipelineConfig,
221
+ RBLNCosmosTransformer3DModel,
222
+ RBLNCosmosTransformer3DModelConfig,
223
+ RBLNCosmosVideoToWorldPipeline,
224
+ RBLNCosmosVideoToWorldPipelineConfig,
203
225
  RBLNDiffusionMixin,
204
226
  RBLNKandinskyV22CombinedPipeline,
205
227
  RBLNKandinskyV22CombinedPipelineConfig,
@@ -295,6 +317,8 @@ if TYPE_CHECKING:
295
317
  RBLNCLIPVisionModelConfig,
296
318
  RBLNCLIPVisionModelWithProjection,
297
319
  RBLNCLIPVisionModelWithProjectionConfig,
320
+ RBLNColPaliForRetrieval,
321
+ RBLNColPaliForRetrievalConfig,
298
322
  RBLNDecoderOnlyModelForCausalLM,
299
323
  RBLNDecoderOnlyModelForCausalLMConfig,
300
324
  RBLNDistilBertForQuestionAnswering,
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.8.0.post2'
21
- __version_tuple__ = version_tuple = (0, 8, 0, 'post2')
20
+ __version__ = version = '0.8.1'
21
+ __version_tuple__ = version_tuple = (0, 8, 1)
@@ -19,6 +19,7 @@ from dataclasses import asdict, dataclass
19
19
  from pathlib import Path
20
20
  from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runtime_checkable
21
21
 
22
+ import numpy as np
22
23
  import torch
23
24
 
24
25
  from .__version__ import __version__
@@ -61,7 +62,7 @@ class RBLNCompileConfig:
61
62
  tensor_parallel_size: Optional[int] = None
62
63
 
63
64
  @staticmethod
64
- def normalize_dtype(dtype):
65
+ def normalize_dtype(dtype: Union[str, torch.dtype, np.dtype]) -> str:
65
66
  """
66
67
  Convert framework-specific dtype to string representation.
67
68
  i.e. torch.float32 -> "float32"
@@ -70,7 +71,7 @@ class RBLNCompileConfig:
70
71
  dtype: The input dtype (can be string, torch dtype, or numpy dtype).
71
72
 
72
73
  Returns:
73
- str: The normalized string representation of the dtype.
74
+ The normalized string representation of the dtype.
74
75
  """
75
76
  if isinstance(dtype, str):
76
77
  return dtype
@@ -147,6 +148,17 @@ class RBLNCompileConfig:
147
148
 
148
149
 
149
150
  RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler"]
151
+ CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
152
+
153
+
154
+ def get_rbln_config_class(rbln_config_class_name: str) -> Type["RBLNModelConfig"]:
155
+ cls = getattr(importlib.import_module("optimum.rbln"), rbln_config_class_name, None)
156
+ if cls is None:
157
+ if rbln_config_class_name in CONFIG_MAPPING:
158
+ cls = CONFIG_MAPPING[rbln_config_class_name]
159
+ else:
160
+ raise ValueError(f"Configuration for {rbln_config_class_name} not found.")
161
+ return cls
150
162
 
151
163
 
152
164
  def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
@@ -166,7 +178,7 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
166
178
  )
167
179
 
168
180
  cls_name = config_file["cls_name"]
169
- cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
181
+ cls = get_rbln_config_class(cls_name)
170
182
  return cls, config_file
171
183
 
172
184
 
@@ -175,7 +187,7 @@ class RBLNAutoConfig:
175
187
  cls_name = kwargs.get("cls_name")
176
188
  if cls_name is None:
177
189
  raise ValueError("`cls_name` is required.")
178
- cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
190
+ cls = get_rbln_config_class(cls_name)
179
191
  return cls(**kwargs)
180
192
 
181
193
  @staticmethod
@@ -183,9 +195,27 @@ class RBLNAutoConfig:
183
195
  cls_name = config_dict.get("cls_name")
184
196
  if cls_name is None:
185
197
  raise ValueError("`cls_name` is required.")
186
- cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
198
+ cls = get_rbln_config_class(cls_name)
187
199
  return cls(**config_dict)
188
200
 
201
+ @staticmethod
202
+ def register(config: Type["RBLNModelConfig"], exist_ok=False):
203
+ """
204
+ Register a new configuration for this class.
205
+
206
+ Args:
207
+ config ([`RBLNModelConfig`]): The config to register.
208
+ """
209
+ if not issubclass(config, RBLNModelConfig):
210
+ raise ValueError("`config` must be a subclass of RBLNModelConfig.")
211
+
212
+ native_cls = getattr(importlib.import_module("optimum.rbln"), config.__name__, None)
213
+ if config.__name__ in CONFIG_MAPPING or native_cls is not None:
214
+ if not exist_ok:
215
+ raise ValueError(f"Configuration for {config.__name__} already registered.")
216
+
217
+ CONFIG_MAPPING[config.__name__] = config
218
+
189
219
  @staticmethod
190
220
  def load(
191
221
  path: str,
@@ -307,9 +337,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
307
337
  # Save to disk
308
338
  config.save("/path/to/model")
309
339
 
310
- # Load configuration from disk
311
- loaded_config = RBLNModelConfig.load("/path/to/model")
312
-
313
340
  # Using AutoConfig
314
341
  loaded_config = RBLNAutoConfig.load("/path/to/model")
315
342
  ```
@@ -462,13 +489,11 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
462
489
  self,
463
490
  submodule_config_cls: Type["RBLNModelConfig"],
464
491
  submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
465
- **kwargs,
492
+ **kwargs: Dict[str, Any],
466
493
  ) -> "RBLNModelConfig":
467
- """
468
- Initialize a submodule config from a dict or a RBLNModelConfig.
494
+ # Initialize a submodule config from a dict or a RBLNModelConfig.
495
+ # kwargs is specified from the predecessor config.
469
496
 
470
- kwargs is specified from the predecessor config.
471
- """
472
497
  if submodule_config is None:
473
498
  submodule_config = {}
474
499
 
@@ -538,7 +563,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
538
563
  tensor_parallel_size: Optional[int] = None,
539
564
  optimum_rbln_version: Optional[str] = None,
540
565
  _compile_cfgs: List[RBLNCompileConfig] = [],
541
- **kwargs,
566
+ **kwargs: Dict[str, Any],
542
567
  ):
543
568
  """
544
569
  Initialize a RBLN model configuration with runtime options and compile configurations.
@@ -608,10 +633,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
608
633
  return rbln_model_cls
609
634
 
610
635
  def _prepare_for_serialization(self) -> Dict[str, Any]:
611
- """
612
- Prepare the attributes map for serialization by converting nested RBLNModelConfig
613
- objects to their serializable form.
614
- """
636
+ # Prepare the attributes map for serialization by converting nested RBLNModelConfig
637
+ # objects to their serializable form.
615
638
  serializable_map = {}
616
639
  for key, value in self._attributes_map.items():
617
640
  if isinstance(value, RBLNSerializableConfigProtocol):
@@ -686,7 +709,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
686
709
  json.dump(serializable_data, jsonf, indent=2)
687
710
 
688
711
  @classmethod
689
- def load(cls, path: str, **kwargs) -> "RBLNModelConfig":
712
+ def load(cls, path: str, **kwargs: Dict[str, Any]) -> "RBLNModelConfig":
690
713
  """
691
714
  Load a RBLNModelConfig from a path.
692
715
 
@@ -719,11 +742,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
719
742
  def initialize_from_kwargs(
720
743
  cls: Type["RBLNModelConfig"],
721
744
  rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
722
- **kwargs,
745
+ **kwargs: Dict[str, Any],
723
746
  ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
724
- """
725
- Initialize RBLNModelConfig from kwargs.
726
- """
747
+ # Initialize RBLNModelConfig from kwargs.
727
748
  kwargs_keys = list(kwargs.keys())
728
749
  rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
729
750
 
@@ -741,16 +762,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
741
762
  return rbln_config, kwargs
742
763
 
743
764
  def get_default_values_for_original_cls(self, func_name: str, keys: List[str]) -> Dict[str, Any]:
744
- """
745
- Get default values for original class attributes from RBLNModelConfig.
746
-
747
- Args:
748
- func_name (str): The name of the function to get the default values for.
749
- keys (List[str]): The keys of the attributes to get.
750
-
751
- Returns:
752
- Dict[str, Any]: The default values for the attributes.
753
- """
765
+ # Get default values for original class attributes from RBLNModelConfig.
754
766
  model_cls = self.rbln_model_cls.get_hf_class()
755
767
  func = getattr(model_cls, func_name)
756
768
  func_signature = inspect.signature(func)
@@ -18,14 +18,21 @@ from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_
18
18
  from transformers.utils import _LazyModule
19
19
 
20
20
 
21
- LOADABLE_CLASSES["optimum.rbln"] = {"RBLNBaseModel": ["save_pretrained", "from_pretrained"]}
21
+ LOADABLE_CLASSES["optimum.rbln"] = {
22
+ "RBLNBaseModel": ["save_pretrained", "from_pretrained"],
23
+ "RBLNCosmosSafetyChecker": ["save_pretrained", "from_pretrained"],
24
+ }
22
25
  ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
23
26
 
24
27
 
25
28
  _import_structure = {
26
29
  "configurations": [
27
30
  "RBLNAutoencoderKLConfig",
31
+ "RBLNAutoencoderKLCosmosConfig",
28
32
  "RBLNControlNetModelConfig",
33
+ "RBLNCosmosTextToWorldPipelineConfig",
34
+ "RBLNCosmosVideoToWorldPipelineConfig",
35
+ "RBLNCosmosTransformer3DModelConfig",
29
36
  "RBLNKandinskyV22CombinedPipelineConfig",
30
37
  "RBLNKandinskyV22Img2ImgCombinedPipelineConfig",
31
38
  "RBLNKandinskyV22Img2ImgPipelineConfig",
@@ -52,6 +59,9 @@ _import_structure = {
52
59
  "RBLNVQModelConfig",
53
60
  ],
54
61
  "pipelines": [
62
+ "RBLNCosmosTextToWorldPipeline",
63
+ "RBLNCosmosVideoToWorldPipeline",
64
+ "RBLNCosmosSafetyChecker",
55
65
  "RBLNKandinskyV22CombinedPipeline",
56
66
  "RBLNKandinskyV22Img2ImgCombinedPipeline",
57
67
  "RBLNKandinskyV22InpaintCombinedPipeline",
@@ -76,8 +86,10 @@ _import_structure = {
76
86
  ],
77
87
  "models": [
78
88
  "RBLNAutoencoderKL",
89
+ "RBLNAutoencoderKLCosmos",
79
90
  "RBLNUNet2DConditionModel",
80
91
  "RBLNControlNetModel",
92
+ "RBLNCosmosTransformer3DModel",
81
93
  "RBLNSD3Transformer2DModel",
82
94
  "RBLNPriorTransformer",
83
95
  "RBLNVQModel",
@@ -90,7 +102,11 @@ _import_structure = {
90
102
  if TYPE_CHECKING:
91
103
  from .configurations import (
92
104
  RBLNAutoencoderKLConfig,
105
+ RBLNAutoencoderKLCosmosConfig,
93
106
  RBLNControlNetModelConfig,
107
+ RBLNCosmosTextToWorldPipelineConfig,
108
+ RBLNCosmosTransformer3DModelConfig,
109
+ RBLNCosmosVideoToWorldPipelineConfig,
94
110
  RBLNKandinskyV22CombinedPipelineConfig,
95
111
  RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
96
112
  RBLNKandinskyV22Img2ImgPipelineConfig,
@@ -120,12 +136,16 @@ if TYPE_CHECKING:
120
136
  from .models import (
121
137
  RBLNAutoencoderKL,
122
138
  RBLNControlNetModel,
139
+ RBLNCosmosTransformer3DModel,
123
140
  RBLNPriorTransformer,
124
141
  RBLNSD3Transformer2DModel,
125
142
  RBLNUNet2DConditionModel,
126
143
  RBLNVQModel,
127
144
  )
128
145
  from .pipelines import (
146
+ RBLNCosmosSafetyChecker,
147
+ RBLNCosmosTextToWorldPipeline,
148
+ RBLNCosmosVideoToWorldPipeline,
129
149
  RBLNKandinskyV22CombinedPipeline,
130
150
  RBLNKandinskyV22Img2ImgCombinedPipeline,
131
151
  RBLNKandinskyV22Img2ImgPipeline,
@@ -1,12 +1,16 @@
1
1
  from .models import (
2
2
  RBLNAutoencoderKLConfig,
3
+ RBLNAutoencoderKLCosmosConfig,
3
4
  RBLNControlNetModelConfig,
5
+ RBLNCosmosTransformer3DModelConfig,
4
6
  RBLNPriorTransformerConfig,
5
7
  RBLNSD3Transformer2DModelConfig,
6
8
  RBLNUNet2DConditionModelConfig,
7
9
  RBLNVQModelConfig,
8
10
  )
9
11
  from .pipelines import (
12
+ RBLNCosmosTextToWorldPipelineConfig,
13
+ RBLNCosmosVideoToWorldPipelineConfig,
10
14
  RBLNKandinskyV22CombinedPipelineConfig,
11
15
  RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
12
16
  RBLNKandinskyV22Img2ImgPipelineConfig,
@@ -1,6 +1,8 @@
1
1
  from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
2
+ from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
2
3
  from .configuration_controlnet import RBLNControlNetModelConfig
3
4
  from .configuration_prior_transformer import RBLNPriorTransformerConfig
5
+ from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
4
6
  from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
5
7
  from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
6
8
  from .configuration_vq_model import RBLNVQModelConfig
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNAutoencoderKLConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN Variational Autoencoder (VAE) models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for VAE models used in diffusion-based image generation.
26
+ """
27
+
21
28
  def __init__(
22
29
  self,
23
30
  batch_size: Optional[int] = None,
@@ -26,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
26
33
  vae_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
27
34
  in_channels: Optional[int] = None,
28
35
  latent_channels: Optional[int] = None,
29
- **kwargs,
36
+ **kwargs: Dict[str, Any],
30
37
  ):
31
38
  """
32
39
  Args:
@@ -0,0 +1,84 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+ from ....utils.logging import get_logger
19
+
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
25
+ """Configuration class for RBLN Cosmos Variational Autoencoder (VAE) models."""
26
+
27
+ def __init__(
28
+ self,
29
+ batch_size: Optional[int] = None,
30
+ uses_encoder: Optional[bool] = None,
31
+ num_frames: Optional[int] = None,
32
+ height: Optional[int] = None,
33
+ width: Optional[int] = None,
34
+ num_channels_latents: Optional[int] = None,
35
+ vae_scale_factor_temporal: Optional[int] = None,
36
+ vae_scale_factor_spatial: Optional[int] = None,
37
+ use_slicing: Optional[bool] = None,
38
+ **kwargs: Dict[str, Any],
39
+ ):
40
+ """
41
+ Args:
42
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
43
+ uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
44
+ When False, only the decoder is used (for latent-to-video conversion).
45
+ num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
46
+ height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
47
+ width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
48
+ num_channels_latents (Optional[int]): The number of channels in latent space.
49
+ vae_scale_factor_temporal (Optional[int]): The scaling factor between time space and latent space.
50
+ Determines how much shorter the latent representations are compared to the original videos.
51
+ vae_scale_factor_spatial (Optional[int]): The scaling factor between pixel space and latent space.
52
+ Determines how much smaller the latent representations are compared to the original videos.
53
+ use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
54
+ If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
55
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
56
+
57
+ Raises:
58
+ ValueError: If batch_size is not a positive integer.
59
+ """
60
+ super().__init__(**kwargs)
61
+ # Since the Cosmos VAE Decoder already requires approximately 7.9 GiB of memory,
62
+ # Optimum-rbln cannot execute this model on RBLN-CA12 when the batch size > 1.
63
+ # However, the Cosmos VAE Decoder supports batch slicing when the batch size is greater than 1,
64
+ # Optimum-rbln utilizes this method by compiling with batch_size=1 to enable batch slicing.
65
+ self.batch_size = batch_size or 1
66
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
67
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
68
+ elif self.batch_size > 1:
69
+ logger.warning("The batch size of Cosmos VAE Decoder will be explicitly 1 for memory efficiency.")
70
+ self.batch_size = 1
71
+
72
+ self.uses_encoder = uses_encoder
73
+ self.num_frames = num_frames or 121
74
+ self.height = height or 704
75
+ self.width = width or 1280
76
+
77
+ self.num_channels_latents = num_channels_latents
78
+ self.vae_scale_factor_temporal = vae_scale_factor_temporal
79
+ self.vae_scale_factor_spatial = vae_scale_factor_spatial
80
+ self.use_slicing = use_slicing or False
81
+
82
+ @property
83
+ def image_size(self):
84
+ return (self.height, self.width)
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNControlNetModelConfig(RBLNModelConfig):
21
+ """Configuration class for RBLN ControlNet models."""
22
+
21
23
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
24
 
23
25
  def __init__(
@@ -27,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
27
29
  unet_sample_size: Optional[Tuple[int, int]] = None,
28
30
  vae_sample_size: Optional[Tuple[int, int]] = None,
29
31
  text_model_hidden_size: Optional[int] = None,
30
- **kwargs,
32
+ **kwargs: Dict[str, Any],
31
33
  ):
32
34
  """
33
35
  Args:
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional
15
+ from typing import Any, Dict, Optional
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNPriorTransformerConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN Prior Transformer models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for Prior Transformer models used in diffusion models like Kandinsky V2.2.
26
+ """
27
+
21
28
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
29
 
23
30
  def __init__(
@@ -25,7 +32,7 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
25
32
  batch_size: Optional[int] = None,
26
33
  embedding_dim: Optional[int] = None,
27
34
  num_embeddings: Optional[int] = None,
28
- **kwargs,
35
+ **kwargs: Dict[str, Any],
29
36
  ):
30
37
  """
31
38
  Args:
@@ -0,0 +1,70 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+
19
+
20
class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
    """Configuration class for RBLN Cosmos Transformer models."""

    def __init__(
        self,
        batch_size: Optional[int] = None,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        fps: Optional[int] = None,
        max_seq_len: Optional[int] = None,
        embedding_dim: Optional[int] = None,
        num_channels_latents: Optional[int] = None,
        num_latent_frames: Optional[int] = None,
        latent_height: Optional[int] = None,
        latent_width: Optional[int] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
            height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
            width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
            fps (Optional[int]): The frames per second of the generated video. Defaults to 30.
            max_seq_len (Optional[int]): Maximum sequence length of prompt embeds.
            embedding_dim (Optional[int]): Embedding vector dimension of prompt embeds.
            num_channels_latents (Optional[int]): The number of channels in latent space.
            num_latent_frames (Optional[int]): The number of frames in latent space.
            latent_height (Optional[int]): The height in pixels in latent space.
            latent_width (Optional[int]): The width in pixels in latent space.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Note:
            For the arguments that have a default (batch_size, num_frames, height,
            width, fps), any falsy value (``None`` or ``0``) is replaced by the default.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)

        # Validate first so an invalid config fails before any other state is set.
        # `or 1` already maps None/0 to 1, so `< 1` only rejects negative values here;
        # it is written this way to match the "positive integer" contract in the message.
        self.batch_size = batch_size or 1
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        # Output video geometry; unset (or falsy) values fall back to the defaults.
        self.num_frames = num_frames or 121
        self.height = height or 704
        self.width = width or 1280
        self.fps = fps or 30

        # Stored exactly as given; these have no default and may remain None.
        self.max_seq_len = max_seq_len
        self.num_channels_latents = num_channels_latents
        self.num_latent_frames = num_latent_frames
        self.latent_height = latent_height
        self.latent_width = latent_width
        self.embedding_dim = embedding_dim
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple, Union
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
21
+ """Configuration class for RBLN Stable Diffusion 3 Transformer models."""
22
+
21
23
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
24
 
23
25
  def __init__(
@@ -25,7 +27,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
25
27
  batch_size: Optional[int] = None,
26
28
  sample_size: Optional[Union[int, Tuple[int, int]]] = None,
27
29
  prompt_embed_length: Optional[int] = None,
28
- **kwargs,
30
+ **kwargs: Dict[str, Any],
29
31
  ):
30
32
  """
31
33
  Args:
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN UNet2DCondition models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for UNet2DCondition models used in diffusion-based image generation.
26
+ """
27
+
21
28
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
29
 
23
30
  def __init__(
@@ -31,7 +38,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
31
38
  in_features: Optional[int] = None,
32
39
  text_model_hidden_size: Optional[int] = None,
33
40
  image_model_hidden_size: Optional[int] = None,
34
- **kwargs,
41
+ **kwargs: Dict[str, Any],
35
42
  ):
36
43
  """
37
44
  Args: