optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -21,8 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runt
21
21
 
22
22
  import numpy as np
23
23
  import torch
24
+ from packaging.version import Version
24
25
 
25
26
  from .__version__ import __version__
27
+ from .utils.depreacate_utils import warn_deprecated_npu
26
28
  from .utils.logging import get_logger
27
29
  from .utils.runtime_utils import ContextRblnConfig
28
30
 
@@ -31,7 +33,6 @@ logger = get_logger(__name__)
31
33
 
32
34
 
33
35
  DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
34
- DEFAULT_MOD_NAME = "default"
35
36
  TypeInputInfo = List[Tuple[str, Tuple[int], str]]
36
37
 
37
38
 
@@ -39,6 +40,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
39
40
  class RBLNSerializableConfigProtocol(Protocol):
40
41
  def _prepare_for_serialization(self) -> Dict[str, Any]: ...
41
42
 
43
+ def __repr__(self) -> str:
44
+ return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
45
+
42
46
 
43
47
  @dataclass
44
48
  class RBLNCompileConfig:
@@ -47,17 +51,13 @@ class RBLNCompileConfig:
47
51
 
48
52
  Attributes:
49
53
  compiled_model_name (str): Name of the compiled model.
50
- mod_name (str): Name of the RBLN module.
51
54
  input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
52
- fusion (Optional[bool]): Whether to use fusion optimization.
53
55
  npu (Optional[str]): NPU configuration.
54
56
  tensor_parallel_size (Optional[int]): Size for tensor parallelism.
55
57
  """
56
58
 
57
59
  compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
58
- mod_name: str = DEFAULT_MOD_NAME
59
60
  input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
60
- fusion: Optional[bool] = None
61
61
  npu: Optional[str] = None
62
62
  tensor_parallel_size: Optional[int] = None
63
63
 
@@ -111,9 +111,7 @@ class RBLNCompileConfig:
111
111
 
112
112
  def update(self, kwargs: Dict[str, Any]):
113
113
  self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
114
- self.mod_name = kwargs.get("mod_name", self.mod_name)
115
114
  self.input_info = kwargs.get("input_info", self.input_info)
116
- self.fusion = kwargs.get("fusion", self.fusion)
117
115
  self.npu = kwargs.get("npu", self.npu)
118
116
  self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
119
117
  return self
@@ -147,7 +145,7 @@ class RBLNCompileConfig:
147
145
  return asdict(self)
148
146
 
149
147
 
150
- RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler", "timeout"]
148
+ RUNTIME_KEYWORDS = ["create_runtimes", "device", "device_map", "activate_profiler", "timeout"]
151
149
  CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
152
150
 
153
151
 
@@ -183,6 +181,15 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
183
181
 
184
182
 
185
183
  class RBLNAutoConfig:
184
+ """
185
+ Resolver and factory for RBLN model configurations.
186
+
187
+ This class selects the concrete `RBLNModelConfig` subclass, validates the
188
+ provided data, and returns a frozen configuration object that serves as the
189
+ single source of truth during export and load. It does not define the schema
190
+ or control model behavior.
191
+ """
192
+
186
193
  def __new__(cls, **kwargs):
187
194
  cls_name = kwargs.get("cls_name")
188
195
  if cls_name is None:
@@ -192,6 +199,33 @@ class RBLNAutoConfig:
192
199
 
193
200
  @staticmethod
194
201
  def load_from_dict(config_dict: Dict[str, Any]) -> "RBLNModelConfig":
202
+ """
203
+ Build a `RBLNModelConfig` from a plain dictionary.
204
+
205
+ The dictionary must contain `cls_name`, which identifies the concrete
206
+ configuration class to instantiate. All other keys are forwarded to the
207
+ target class initializer. This method does not mutate `config_dict`.
208
+
209
+ Args:
210
+ config_dict: Mapping typically created by `json.load` or `yaml.safe_load`.
211
+ For example, the parsed contents of `rbln_config.json`.
212
+
213
+ Returns:
214
+ RBLNModelConfig: A configuration instance. The specific subclass is
215
+ selected by `config_dict["cls_name"]`.
216
+
217
+ Raises:
218
+ ValueError: If `cls_name` is missing.
219
+ Exception: Any error raised by the target config class during init.
220
+
221
+ Examples:
222
+ >>> data = {
223
+ ... "cls_name": "RBLNLlamaForCausalLMConfig",
224
+ ... "create_runtimes": False,
225
+ ... "tensor_parallel_size": 4
226
+ ... }
227
+ >>> cfg = RBLNAutoConfig.load_from_dict(data)
228
+ """
195
229
  cls_name = config_dict.get("cls_name")
196
230
  if cls_name is None:
197
231
  raise ValueError("`cls_name` is required.")
@@ -204,7 +238,8 @@ class RBLNAutoConfig:
204
238
  Register a new configuration for this class.
205
239
 
206
240
  Args:
207
- config ([`RBLNModelConfig`]): The config to register.
241
+ config (RBLNModelConfig): The config to register.
242
+ exist_ok (bool): Whether to allow registering an already registered model.
208
243
  """
209
244
  if not issubclass(config, RBLNModelConfig):
210
245
  raise ValueError("`config` must be a subclass of RBLNModelConfig.")
@@ -246,9 +281,6 @@ class RBLNAutoConfig:
246
281
  if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
247
282
  }
248
283
 
249
- if len(rbln_kwargs) > 0:
250
- raise ValueError(f"Cannot set the following arguments: {list(rbln_kwargs.keys())}")
251
-
252
284
  # Process submodule's rbln_config
253
285
  for submodule in cls.submodules:
254
286
  if submodule not in config_file:
@@ -263,6 +295,16 @@ class RBLNAutoConfig:
263
295
 
264
296
  config_file.update(rbln_runtime_kwargs)
265
297
 
298
+ rbln_config = cls(**config_file)
299
+
300
+ if len(rbln_kwargs) > 0:
301
+ for key, value in rbln_kwargs.items():
302
+ if getattr(rbln_config, key) != value:
303
+ raise ValueError(
304
+ f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
305
+ f"Since the value is already set to {getattr(rbln_config, key)}"
306
+ )
307
+
266
308
  if return_unused_kwargs:
267
309
  return cls(**config_file), kwargs
268
310
  else:
@@ -273,6 +315,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
273
315
  """Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.
274
316
 
275
317
  This class provides functionality for:
318
+
276
319
  1. Managing compilation configurations for RBLN devices
277
320
  2. Configuring runtime behavior such as device placement
278
321
  3. Handling nested configuration objects for complex model architectures
@@ -474,10 +517,10 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
474
517
  non_save_attributes = [
475
518
  "_frozen",
476
519
  "_runtime_options",
520
+ "torch_dtype",
477
521
  "npu",
478
522
  "tensor_parallel_size",
479
523
  "create_runtimes",
480
- "optimize_host_memory",
481
524
  "device",
482
525
  "device_map",
483
526
  "activate_profiler",
@@ -486,18 +529,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
486
529
  submodules: List[str] = []
487
530
  subclass_non_save_attributes = []
488
531
 
489
- def init_submodule_config(
532
+ def initialize_submodule_config(
490
533
  self,
491
- submodule_config_cls: Type["RBLNModelConfig"],
492
534
  submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
493
- **kwargs: Dict[str, Any],
535
+ force_kwargs: bool = False,
536
+ **kwargs: Any,
494
537
  ) -> "RBLNModelConfig":
495
- # Initialize a submodule config from a dict or a RBLNModelConfig.
496
- # kwargs is specified from the predecessor config.
497
-
498
538
  if submodule_config is None:
499
539
  submodule_config = {}
500
540
 
541
+ if isinstance(submodule_config, RBLNModelConfig):
542
+ return submodule_config
543
+
501
544
  if isinstance(submodule_config, dict):
502
545
  from_predecessor = self._runtime_options.copy()
503
546
  from_predecessor.update(
@@ -511,13 +554,60 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
511
554
 
512
555
  init_kwargs = from_predecessor
513
556
  init_kwargs.update(submodule_config)
514
- submodule_config = submodule_config_cls(**init_kwargs)
515
557
 
516
- if not isinstance(submodule_config, submodule_config_cls):
558
+ if force_kwargs:
559
+ for key, value in kwargs.items():
560
+ if key in init_kwargs:
561
+ if init_kwargs[key] != value:
562
+ raise ValueError(
563
+ f"Parameter conflict for '{key}': submodule_config has {init_kwargs[key]}, "
564
+ f"but kwargs has {value}. Using kwargs value: {value}"
565
+ )
566
+ init_kwargs[key] = value
567
+
568
+ if "cls_name" in init_kwargs:
569
+ config_cls = get_rbln_config_class(init_kwargs["cls_name"])
570
+ else:
571
+ return init_kwargs
572
+
573
+ submodule_config = config_cls(**init_kwargs)
574
+
575
+ if not isinstance(submodule_config, RBLNModelConfig):
517
576
  raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")
518
577
 
519
578
  return submodule_config
520
579
 
580
+ def filter_parameters(self, config_cls: Type["RBLNModelConfig"], parameters: Dict[str, Any]) -> Dict[str, Any]:
581
+ import importlib
582
+
583
+ model_cls_name = config_cls.__name__.replace("Config", "")
584
+ modeling_module_name = config_cls.__module__.replace("configuration_", "modeling_")
585
+
586
+ model_cls = None
587
+ try:
588
+ modeling_module = importlib.import_module(modeling_module_name)
589
+ if hasattr(modeling_module, model_cls_name):
590
+ model_cls = getattr(modeling_module, model_cls_name)
591
+ except ImportError:
592
+ logger.debug(f"Could not import modeling module: {modeling_module_name}")
593
+
594
+ filtered_out_params = set()
595
+
596
+ if model_cls is not None:
597
+ if not getattr(model_cls, "_tp_support", False):
598
+ filtered_out_params.add("tensor_parallel_size")
599
+
600
+ filtered_params = {}
601
+ for key, value in parameters.items():
602
+ if key in filtered_out_params:
603
+ logger.debug(
604
+ f"Parameter '{key}' filtered out for {config_cls.__name__} (not supported by model flags)."
605
+ )
606
+ else:
607
+ filtered_params[key] = value
608
+
609
+ return filtered_params
610
+
521
611
  def __setattr__(self, key, value):
522
612
  if (
523
613
  key != "_attributes_map"
@@ -556,7 +646,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
556
646
  self,
557
647
  cls_name: Optional[str] = None,
558
648
  create_runtimes: Optional[bool] = None,
559
- optimize_host_memory: Optional[bool] = None,
560
649
  device: Optional[Union[int, List[int]]] = None,
561
650
  device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
562
651
  activate_profiler: Optional[bool] = None,
@@ -564,8 +653,11 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
564
653
  tensor_parallel_size: Optional[int] = None,
565
654
  timeout: Optional[int] = None,
566
655
  optimum_rbln_version: Optional[str] = None,
656
+ _torch_dtype: Optional[str] = None,
567
657
  _compile_cfgs: List[RBLNCompileConfig] = [],
568
- **kwargs: Dict[str, Any],
658
+ *,
659
+ optimize_host_memory: Optional[bool] = None,
660
+ **kwargs: Any,
569
661
  ):
570
662
  """
571
663
  Initialize a RBLN model configuration with runtime options and compile configurations.
@@ -573,7 +665,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
573
665
  Args:
574
666
  cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
575
667
  create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True.
576
- optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
577
668
  device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
578
669
  device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
579
670
  activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
@@ -581,8 +672,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
581
672
  tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
582
673
  timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
583
674
  optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
675
+ _torch_dtype (Optional[str]): The data type to use for the model.
584
676
  _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
585
- **kwargs: Additional keyword arguments.
677
+ kwargs: Additional keyword arguments.
586
678
 
587
679
  Raises:
588
680
  ValueError: If unexpected keyword arguments are provided.
@@ -598,16 +690,19 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
598
690
 
599
691
  self._runtime_options = {}
600
692
  self._runtime_options["create_runtimes"] = create_runtimes
601
- self._runtime_options["optimize_host_memory"] = optimize_host_memory
602
693
  self._runtime_options["device"] = device
603
694
  self._runtime_options["device_map"] = device_map
604
695
  self._runtime_options["activate_profiler"] = activate_profiler
605
696
  self._runtime_options["timeout"] = timeout
606
697
 
698
+ if optimize_host_memory is not None:
699
+ logger.warning("`optimize_host_memory` is deprecated and will be removed in future versions.")
700
+
607
701
  # Automatically pass npu, tensor_parallel_size to compile_cfgs
608
702
  self.npu = npu
609
703
  self.tensor_parallel_size = tensor_parallel_size
610
704
 
705
+ self._torch_dtype = _torch_dtype or "float32"
611
706
  self.optimum_rbln_version = optimum_rbln_version
612
707
  if self.optimum_rbln_version is None:
613
708
  self.optimum_rbln_version = __version__
@@ -620,8 +715,34 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
620
715
  self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])
621
716
 
622
717
  if len(kwargs) > 0:
718
+ if optimum_rbln_version is not None: # loaded from file
719
+ if Version(__version__) < Version(optimum_rbln_version):
720
+ diff = "newer"
721
+ elif Version(__version__) > Version(optimum_rbln_version):
722
+ diff = "older"
723
+ else:
724
+ diff = None
725
+ if diff is not None:
726
+ raise ValueError(
727
+ f"Unexpected arguments: {kwargs.keys()}\n"
728
+ f"Maybe you are trying to load a model compiled with {diff} version of optimum-rbln. "
729
+ "It is recommended to use the same version to compile and load the model.\n"
730
+ f"Current version: {__version__}, Loaded version: {optimum_rbln_version}"
731
+ )
732
+
623
733
  raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
624
734
 
735
+ @property
736
+ def torch_dtype(self):
737
+ return getattr(torch, self._torch_dtype)
738
+
739
+ @torch_dtype.setter
740
+ def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
741
+ if isinstance(torch_dtype, torch.dtype):
742
+ torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
743
+
744
+ self._torch_dtype = torch_dtype
745
+
625
746
  @property
626
747
  def rbln_model_cls_name(self) -> str:
627
748
  return self.__class__.__name__[:-6]
@@ -675,6 +796,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
675
796
  compile_cfg.npu = self.npu
676
797
  compile_cfg.tensor_parallel_size = self.tensor_parallel_size
677
798
 
799
+ target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
800
+ warn_deprecated_npu(target_npu)
801
+
678
802
  def freeze(self):
679
803
  if self._frozen:
680
804
  raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
@@ -713,13 +837,13 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
713
837
  json.dump(serializable_data, jsonf, indent=2)
714
838
 
715
839
  @classmethod
716
- def load(cls, path: str, **kwargs: Dict[str, Any]) -> "RBLNModelConfig":
840
+ def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
717
841
  """
718
842
  Load a RBLNModelConfig from a path.
719
843
 
720
844
  Args:
721
845
  path (str): Path to the RBLNModelConfig file or directory containing the config file.
722
- **kwargs: Additional keyword arguments to override configuration values.
846
+ kwargs: Additional keyword arguments to override configuration values.
723
847
  Keys starting with 'rbln_' will have the prefix removed and be used
724
848
  to update the configuration.
725
849
 
@@ -746,7 +870,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
746
870
  def initialize_from_kwargs(
747
871
  cls: Type["RBLNModelConfig"],
748
872
  rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
749
- **kwargs: Dict[str, Any],
873
+ **kwargs: Any,
750
874
  ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
751
875
  # Initialize RBLNModelConfig from kwargs.
752
876
  kwargs_keys = list(kwargs.keys())
@@ -791,19 +915,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
791
915
  def create_runtimes(self, create_runtimes: bool):
792
916
  self._runtime_options["create_runtimes"] = create_runtimes
793
917
 
794
- @property
795
- def optimize_host_memory(self):
796
- context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
797
- if context is not None:
798
- return context
799
- elif self._runtime_options["optimize_host_memory"] is None:
800
- return True
801
- return self._runtime_options["optimize_host_memory"]
802
-
803
- @optimize_host_memory.setter
804
- def optimize_host_memory(self, optimize_host_memory: bool):
805
- self._runtime_options["optimize_host_memory"] = optimize_host_memory
806
-
807
918
  @property
808
919
  def device(self):
809
920
  context = ContextRblnConfig.get_current_context()["device"]
@@ -59,6 +59,9 @@ _import_structure = {
59
59
  "RBLNVQModelConfig",
60
60
  ],
61
61
  "pipelines": [
62
+ "RBLNAutoPipelineForImage2Image",
63
+ "RBLNAutoPipelineForInpainting",
64
+ "RBLNAutoPipelineForText2Image",
62
65
  "RBLNCosmosTextToWorldPipeline",
63
66
  "RBLNCosmosVideoToWorldPipeline",
64
67
  "RBLNCosmosSafetyChecker",
@@ -135,6 +138,7 @@ if TYPE_CHECKING:
135
138
  from .modeling_diffusers import RBLNDiffusionMixin
136
139
  from .models import (
137
140
  RBLNAutoencoderKL,
141
+ RBLNAutoencoderKLCosmos,
138
142
  RBLNControlNetModel,
139
143
  RBLNCosmosTransformer3DModel,
140
144
  RBLNPriorTransformer,
@@ -143,6 +147,9 @@ if TYPE_CHECKING:
143
147
  RBLNVQModel,
144
148
  )
145
149
  from .pipelines import (
150
+ RBLNAutoPipelineForImage2Image,
151
+ RBLNAutoPipelineForInpainting,
152
+ RBLNAutoPipelineForText2Image,
146
153
  RBLNCosmosSafetyChecker,
147
154
  RBLNCosmosTextToWorldPipeline,
148
155
  RBLNCosmosVideoToWorldPipeline,
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
@@ -33,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
33
33
  vae_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
34
34
  in_channels: Optional[int] = None,
35
35
  latent_channels: Optional[int] = None,
36
- **kwargs: Dict[str, Any],
36
+ **kwargs: Any,
37
37
  ):
38
38
  """
39
39
  Args:
@@ -46,7 +46,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
46
46
  Determines how much smaller the latent representations are compared to the original images.
47
47
  in_channels (Optional[int]): Number of input channels for the model.
48
48
  latent_channels (Optional[int]): Number of channels in the latent space.
49
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
49
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
50
50
 
51
51
  Raises:
52
52
  ValueError: If batch_size is not a positive integer.
@@ -52,7 +52,7 @@ class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
52
52
  Determines how much smaller the latent representations are compared to the original videos.
53
53
  use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
54
54
  If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
55
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
55
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
56
56
 
57
57
  Raises:
58
58
  ValueError: If batch_size is not a positive integer.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
@@ -29,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
29
29
  unet_sample_size: Optional[Tuple[int, int]] = None,
30
30
  vae_sample_size: Optional[Tuple[int, int]] = None,
31
31
  text_model_hidden_size: Optional[int] = None,
32
- **kwargs: Dict[str, Any],
32
+ **kwargs: Any,
33
33
  ):
34
34
  """
35
35
  Args:
@@ -42,7 +42,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
42
42
  of the VAE input/output images.
43
43
  text_model_hidden_size (Optional[int]): Hidden size of the text encoder model used
44
44
  for conditioning.
45
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
45
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
46
46
 
47
47
  Raises:
48
48
  ValueError: If batch_size is not a positive integer.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional
15
+ from typing import Any, Optional
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
@@ -22,7 +22,7 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
22
22
  Configuration class for RBLN Prior Transformer models.
23
23
 
24
24
  This class inherits from RBLNModelConfig and provides specific configuration options
25
- for Prior Transformer models used in diffusion models like Kandinsky V2.2.
25
+ for Transformer models used in diffusion models like Kandinsky V2.2.
26
26
  """
27
27
 
28
28
  subclass_non_save_attributes = ["_batch_size_is_specified"]
@@ -32,14 +32,14 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
32
32
  batch_size: Optional[int] = None,
33
33
  embedding_dim: Optional[int] = None,
34
34
  num_embeddings: Optional[int] = None,
35
- **kwargs: Dict[str, Any],
35
+ **kwargs: Any,
36
36
  ):
37
37
  """
38
38
  Args:
39
39
  batch_size (Optional[int]): The batch size for inference. Defaults to 1.
40
40
  embedding_dim (Optional[int]): Dimension of the embedding vectors in the model.
41
41
  num_embeddings (Optional[int]): Number of discrete embeddings in the codebook.
42
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
42
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
43
43
 
44
44
  Raises:
45
45
  ValueError: If batch_size is not a positive integer.
@@ -12,13 +12,18 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional
15
+ from typing import Any, Optional
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
21
- """Configuration class for RBLN Cosmos Transformer models."""
21
+ """
22
+ Configuration class for RBLN Cosmos Transformer models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for Transformer models used in diffusion models like Cosmos.
26
+ """
22
27
 
23
28
  def __init__(
24
29
  self,
@@ -33,7 +38,7 @@ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
33
38
  num_latent_frames: Optional[int] = None,
34
39
  latent_height: Optional[int] = None,
35
40
  latent_width: Optional[int] = None,
36
- **kwargs: Dict[str, Any],
41
+ **kwargs: Any,
37
42
  ):
38
43
  """
39
44
  Args:
@@ -47,7 +52,7 @@ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
47
52
  num_channels_latents (Optional[int]): The number of channels in latent space.
48
53
  latent_height (Optional[int]): The height in pixels in latent space.
49
54
  latent_width (Optional[int]): The width in pixels in latent space.
50
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
55
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
51
56
 
52
57
  Raises:
53
58
  ValueError: If batch_size is not a positive integer.
@@ -12,13 +12,18 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple, Union
15
+ from typing import Any, Optional, Tuple, Union
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
21
- """Configuration class for RBLN Stable Diffusion 3 Transformer models."""
21
+ """
22
+ Configuration class for RBLN Stable Diffusion 3 Transformer models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for Transformer models used in diffusion models like Stable Diffusion 3.
26
+ """
22
27
 
23
28
  subclass_non_save_attributes = ["_batch_size_is_specified"]
24
29
 
@@ -27,7 +32,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
27
32
  batch_size: Optional[int] = None,
28
33
  sample_size: Optional[Union[int, Tuple[int, int]]] = None,
29
34
  prompt_embed_length: Optional[int] = None,
30
- **kwargs: Dict[str, Any],
35
+ **kwargs: Any,
31
36
  ):
32
37
  """
33
38
  Args:
@@ -36,7 +41,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
36
41
  of the generated samples. If an integer is provided, it's used for both height and width.
37
42
  prompt_embed_length (Optional[int]): The length of the embedded prompt vectors that
38
43
  will be used to condition the transformer model.
39
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
44
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
40
45
 
41
46
  Raises:
42
47
  ValueError: If batch_size is not a positive integer.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
@@ -38,7 +38,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
38
38
  in_features: Optional[int] = None,
39
39
  text_model_hidden_size: Optional[int] = None,
40
40
  image_model_hidden_size: Optional[int] = None,
41
- **kwargs: Dict[str, Any],
41
+ **kwargs: Any,
42
42
  ):
43
43
  """
44
44
  Args:
@@ -52,7 +52,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
52
52
  in_features (Optional[int]): Number of input features for the model.
53
53
  text_model_hidden_size (Optional[int]): Hidden size of the text encoder model.
54
54
  image_model_hidden_size (Optional[int]): Hidden size of the image encoder model.
55
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
55
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
56
56
 
57
57
  Raises:
58
58
  ValueError: If batch_size is not a positive integer.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
@@ -33,7 +33,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
33
33
  vqmodel_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
34
34
  in_channels: Optional[int] = None,
35
35
  latent_channels: Optional[int] = None,
36
- **kwargs: Dict[str, Any],
36
+ **kwargs: Any,
37
37
  ):
38
38
  """
39
39
  Args:
@@ -46,7 +46,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
46
46
  Determines the downsampling ratio between original images and latent representations.
47
47
  in_channels (Optional[int]): Number of input channels for the model.
48
48
  latent_channels (Optional[int]): Number of channels in the latent space.
49
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
49
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
50
50
 
51
51
  Raises:
52
52
  ValueError: If batch_size is not a positive integer.