optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +53 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  54. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  55. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  56. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  57. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  58. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  59. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
  60. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  63. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  64. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  65. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  66. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  67. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  68. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  69. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  70. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  71. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
  72. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  73. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  74. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  75. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  76. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  77. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  78. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  79. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  80. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  81. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  82. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  83. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  84. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  85. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  86. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  87. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  91. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  92. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  93. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  94. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
  95. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
  96. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  97. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  98. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  99. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  100. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  101. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  102. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  103. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  104. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  105. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  106. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  108. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  109. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  110. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  111. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  112. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  113. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  114. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  115. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  116. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  117. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  118. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  119. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  123. optimum/rbln/utils/model_utils.py +20 -0
  124. optimum/rbln/utils/submodule.py +6 -8
  125. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
  126. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
  127. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  128. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  129. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
  130. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -26,6 +26,7 @@ _import_structure = {
26
26
  "RBLNModel",
27
27
  ],
28
28
  "configuration_utils": [
29
+ "RBLNAutoConfig",
29
30
  "RBLNCompileConfig",
30
31
  "RBLNModelConfig",
31
32
  ],
@@ -192,6 +193,7 @@ _import_structure = {
192
193
 
193
194
  if TYPE_CHECKING:
194
195
  from .configuration_utils import (
196
+ RBLNAutoConfig,
195
197
  RBLNCompileConfig,
196
198
  RBLNModelConfig,
197
199
  )
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.8.1a0'
21
- __version_tuple__ = version_tuple = (0, 8, 1, 'a0')
20
+ __version__ = version = '0.8.1a2'
21
+ __version_tuple__ = version_tuple = (0, 8, 1, 'a2')
@@ -19,6 +19,7 @@ from dataclasses import asdict, dataclass
19
19
  from pathlib import Path
20
20
  from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runtime_checkable
21
21
 
22
+ import numpy as np
22
23
  import torch
23
24
 
24
25
  from .__version__ import __version__
@@ -61,7 +62,7 @@ class RBLNCompileConfig:
61
62
  tensor_parallel_size: Optional[int] = None
62
63
 
63
64
  @staticmethod
64
- def normalize_dtype(dtype):
65
+ def normalize_dtype(dtype: Union[str, torch.dtype, np.dtype]) -> str:
65
66
  """
66
67
  Convert framework-specific dtype to string representation.
67
68
  i.e. torch.float32 -> "float32"
@@ -70,7 +71,7 @@ class RBLNCompileConfig:
70
71
  dtype: The input dtype (can be string, torch dtype, or numpy dtype).
71
72
 
72
73
  Returns:
73
- str: The normalized string representation of the dtype.
74
+ The normalized string representation of the dtype.
74
75
  """
75
76
  if isinstance(dtype, str):
76
77
  return dtype
@@ -147,6 +148,17 @@ class RBLNCompileConfig:
147
148
 
148
149
 
149
150
  RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler"]
151
+ CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
152
+
153
+
154
+ def get_rbln_config_class(rbln_config_class_name: str) -> Type["RBLNModelConfig"]:
155
+ cls = getattr(importlib.import_module("optimum.rbln"), rbln_config_class_name, None)
156
+ if cls is None:
157
+ if rbln_config_class_name in CONFIG_MAPPING:
158
+ cls = CONFIG_MAPPING[rbln_config_class_name]
159
+ else:
160
+ raise ValueError(f"Configuration for {rbln_config_class_name} not found.")
161
+ return cls
150
162
 
151
163
 
152
164
  def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
@@ -166,7 +178,7 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
166
178
  )
167
179
 
168
180
  cls_name = config_file["cls_name"]
169
- cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
181
+ cls = get_rbln_config_class(cls_name)
170
182
  return cls, config_file
171
183
 
172
184
 
@@ -175,7 +187,7 @@ class RBLNAutoConfig:
175
187
  cls_name = kwargs.get("cls_name")
176
188
  if cls_name is None:
177
189
  raise ValueError("`cls_name` is required.")
178
- cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
190
+ cls = get_rbln_config_class(cls_name)
179
191
  return cls(**kwargs)
180
192
 
181
193
  @staticmethod
@@ -183,9 +195,27 @@ class RBLNAutoConfig:
183
195
  cls_name = config_dict.get("cls_name")
184
196
  if cls_name is None:
185
197
  raise ValueError("`cls_name` is required.")
186
- cls = getattr(importlib.import_module("optimum.rbln"), cls_name)
198
+ cls = get_rbln_config_class(cls_name)
187
199
  return cls(**config_dict)
188
200
 
201
+ @staticmethod
202
+ def register(config: Type["RBLNModelConfig"], exist_ok=False):
203
+ """
204
+ Register a new configuration for this class.
205
+
206
+ Args:
207
+ config ([`RBLNModelConfig`]): The config to register.
208
+ """
209
+ if not issubclass(config, RBLNModelConfig):
210
+ raise ValueError("`config` must be a subclass of RBLNModelConfig.")
211
+
212
+ native_cls = getattr(importlib.import_module("optimum.rbln"), config.__name__, None)
213
+ if config.__name__ in CONFIG_MAPPING or native_cls is not None:
214
+ if not exist_ok:
215
+ raise ValueError(f"Configuration for {config.__name__} already registered.")
216
+
217
+ CONFIG_MAPPING[config.__name__] = config
218
+
189
219
  @staticmethod
190
220
  def load(
191
221
  path: str,
@@ -307,9 +337,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
307
337
  # Save to disk
308
338
  config.save("/path/to/model")
309
339
 
310
- # Load configuration from disk
311
- loaded_config = RBLNModelConfig.load("/path/to/model")
312
-
313
340
  # Using AutoConfig
314
341
  loaded_config = RBLNAutoConfig.load("/path/to/model")
315
342
  ```
@@ -462,19 +489,25 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
462
489
  self,
463
490
  submodule_config_cls: Type["RBLNModelConfig"],
464
491
  submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
465
- **kwargs,
492
+ **kwargs: Dict[str, Any],
466
493
  ) -> "RBLNModelConfig":
467
- """
468
- Initialize a submodule config from a dict or a RBLNModelConfig.
494
+ # Initialize a submodule config from a dict or a RBLNModelConfig.
495
+ # kwargs is specified from the predecessor config.
469
496
 
470
- kwargs is specified from the predecessor config.
471
- """
472
497
  if submodule_config is None:
473
498
  submodule_config = {}
474
499
 
475
500
  if isinstance(submodule_config, dict):
476
501
  from_predecessor = self._runtime_options.copy()
502
+ from_predecessor.update(
503
+ {
504
+ "npu": self.npu,
505
+ "tensor_parallel_size": self.tensor_parallel_size,
506
+ "optimum_rbln_version": self.optimum_rbln_version,
507
+ }
508
+ )
477
509
  from_predecessor.update(kwargs)
510
+
478
511
  init_kwargs = from_predecessor
479
512
  init_kwargs.update(submodule_config)
480
513
  submodule_config = submodule_config_cls(**init_kwargs)
@@ -530,7 +563,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
530
563
  tensor_parallel_size: Optional[int] = None,
531
564
  optimum_rbln_version: Optional[str] = None,
532
565
  _compile_cfgs: List[RBLNCompileConfig] = [],
533
- **kwargs,
566
+ **kwargs: Dict[str, Any],
534
567
  ):
535
568
  """
536
569
  Initialize a RBLN model configuration with runtime options and compile configurations.
@@ -600,10 +633,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
600
633
  return rbln_model_cls
601
634
 
602
635
  def _prepare_for_serialization(self) -> Dict[str, Any]:
603
- """
604
- Prepare the attributes map for serialization by converting nested RBLNModelConfig
605
- objects to their serializable form.
606
- """
636
+ # Prepare the attributes map for serialization by converting nested RBLNModelConfig
637
+ # objects to their serializable form.
607
638
  serializable_map = {}
608
639
  for key, value in self._attributes_map.items():
609
640
  if isinstance(value, RBLNSerializableConfigProtocol):
@@ -678,7 +709,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
678
709
  json.dump(serializable_data, jsonf, indent=2)
679
710
 
680
711
  @classmethod
681
- def load(cls, path: str, **kwargs) -> "RBLNModelConfig":
712
+ def load(cls, path: str, **kwargs: Dict[str, Any]) -> "RBLNModelConfig":
682
713
  """
683
714
  Load a RBLNModelConfig from a path.
684
715
 
@@ -711,11 +742,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
711
742
  def initialize_from_kwargs(
712
743
  cls: Type["RBLNModelConfig"],
713
744
  rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
714
- **kwargs,
745
+ **kwargs: Dict[str, Any],
715
746
  ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
716
- """
717
- Initialize RBLNModelConfig from kwargs.
718
- """
747
+ # Initialize RBLNModelConfig from kwargs.
719
748
  kwargs_keys = list(kwargs.keys())
720
749
  rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
721
750
 
@@ -733,16 +762,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
733
762
  return rbln_config, kwargs
734
763
 
735
764
  def get_default_values_for_original_cls(self, func_name: str, keys: List[str]) -> Dict[str, Any]:
736
- """
737
- Get default values for original class attributes from RBLNModelConfig.
738
-
739
- Args:
740
- func_name (str): The name of the function to get the default values for.
741
- keys (List[str]): The keys of the attributes to get.
742
-
743
- Returns:
744
- Dict[str, Any]: The default values for the attributes.
745
- """
765
+ # Get default values for original class attributes from RBLNModelConfig.
746
766
  model_cls = self.rbln_model_cls.get_hf_class()
747
767
  func = getattr(model_cls, func_name)
748
768
  func_signature = inspect.signature(func)
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNAutoencoderKLConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN Variational Autoencoder (VAE) models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for VAE models used in diffusion-based image generation.
26
+ """
27
+
21
28
  def __init__(
22
29
  self,
23
30
  batch_size: Optional[int] = None,
@@ -26,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
26
33
  vae_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
27
34
  in_channels: Optional[int] = None,
28
35
  latent_channels: Optional[int] = None,
29
- **kwargs,
36
+ **kwargs: Dict[str, Any],
30
37
  ):
31
38
  """
32
39
  Args:
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNControlNetModelConfig(RBLNModelConfig):
21
+ """Configuration class for RBLN ControlNet models."""
22
+
21
23
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
24
 
23
25
  def __init__(
@@ -27,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
27
29
  unet_sample_size: Optional[Tuple[int, int]] = None,
28
30
  vae_sample_size: Optional[Tuple[int, int]] = None,
29
31
  text_model_hidden_size: Optional[int] = None,
30
- **kwargs,
32
+ **kwargs: Dict[str, Any],
31
33
  ):
32
34
  """
33
35
  Args:
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional
15
+ from typing import Any, Dict, Optional
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNPriorTransformerConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN Prior Transformer models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for Prior Transformer models used in diffusion models like Kandinsky V2.2.
26
+ """
27
+
21
28
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
29
 
23
30
  def __init__(
@@ -25,7 +32,7 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
25
32
  batch_size: Optional[int] = None,
26
33
  embedding_dim: Optional[int] = None,
27
34
  num_embeddings: Optional[int] = None,
28
- **kwargs,
35
+ **kwargs: Dict[str, Any],
29
36
  ):
30
37
  """
31
38
  Args:
@@ -12,12 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple, Union
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
21
+ """Configuration class for RBLN Stable Diffusion 3 Transformer models."""
22
+
21
23
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
24
 
23
25
  def __init__(
@@ -25,7 +27,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
25
27
  batch_size: Optional[int] = None,
26
28
  sample_size: Optional[Union[int, Tuple[int, int]]] = None,
27
29
  prompt_embed_length: Optional[int] = None,
28
- **kwargs,
30
+ **kwargs: Dict[str, Any],
29
31
  ):
30
32
  """
31
33
  Args:
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN UNet2DCondition models.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for UNet2DCondition models used in diffusion-based image generation.
26
+ """
27
+
21
28
  subclass_non_save_attributes = ["_batch_size_is_specified"]
22
29
 
23
30
  def __init__(
@@ -31,7 +38,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
31
38
  in_features: Optional[int] = None,
32
39
  text_model_hidden_size: Optional[int] = None,
33
40
  image_model_hidden_size: Optional[int] = None,
34
- **kwargs,
41
+ **kwargs: Dict[str, Any],
35
42
  ):
36
43
  """
37
44
  Args:
@@ -12,12 +12,19 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNVQModelConfig(RBLNModelConfig):
21
+ """
22
+ Configuration class for RBLN VQModel models, used in Kandinsky.
23
+
24
+ This class inherits from RBLNModelConfig and provides specific configuration options
25
+ for VQModel, which acts similarly to a VAE but uses vector quantization.
26
+ """
27
+
21
28
  def __init__(
22
29
  self,
23
30
  batch_size: Optional[int] = None,
@@ -26,7 +33,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
26
33
  vqmodel_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
27
34
  in_channels: Optional[int] = None,
28
35
  latent_channels: Optional[int] = None,
29
- **kwargs,
36
+ **kwargs: Dict[str, Any],
30
37
  ):
31
38
  """
32
39
  Args:
@@ -12,14 +12,18 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
  from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
19
19
  from ..models import RBLNAutoencoderKLConfig, RBLNControlNetModelConfig, RBLNUNet2DConditionModelConfig
20
20
 
21
21
 
22
- class _RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
22
+ class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
23
+ """
24
+ Base configuration for Stable Diffusion ControlNet pipelines.
25
+ """
26
+
23
27
  submodules = ["text_encoder", "unet", "vae", "controlnet"]
24
28
  _vae_uses_encoder = False
25
29
 
@@ -38,7 +42,7 @@ class _RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
38
42
  sample_size: Optional[Tuple[int, int]] = None,
39
43
  image_size: Optional[Tuple[int, int]] = None,
40
44
  guidance_scale: Optional[float] = None,
41
- **kwargs,
45
+ **kwargs: Dict[str, Any],
42
46
  ):
43
47
  """
44
48
  Args:
@@ -138,15 +142,27 @@ class _RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
138
142
  return self.vae.sample_size
139
143
 
140
144
 
141
- class RBLNStableDiffusionControlNetPipelineConfig(_RBLNStableDiffusionControlNetPipelineBaseConfig):
145
+ class RBLNStableDiffusionControlNetPipelineConfig(RBLNStableDiffusionControlNetPipelineBaseConfig):
146
+ """
147
+ Configuration for Stable Diffusion ControlNet pipeline.
148
+ """
149
+
142
150
  _vae_uses_encoder = False
143
151
 
144
152
 
145
- class RBLNStableDiffusionControlNetImg2ImgPipelineConfig(_RBLNStableDiffusionControlNetPipelineBaseConfig):
153
+ class RBLNStableDiffusionControlNetImg2ImgPipelineConfig(RBLNStableDiffusionControlNetPipelineBaseConfig):
154
+ """
155
+ Configuration for Stable Diffusion ControlNet image-to-image pipeline.
156
+ """
157
+
146
158
  _vae_uses_encoder = True
147
159
 
148
160
 
149
- class _RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
161
+ class RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
162
+ """
163
+ Base configuration for Stable Diffusion XL ControlNet pipelines.
164
+ """
165
+
150
166
  submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
151
167
  _vae_uses_encoder = False
152
168
 
@@ -166,7 +182,7 @@ class _RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
166
182
  sample_size: Optional[Tuple[int, int]] = None,
167
183
  image_size: Optional[Tuple[int, int]] = None,
168
184
  guidance_scale: Optional[float] = None,
169
- **kwargs,
185
+ **kwargs: Dict[str, Any],
170
186
  ):
171
187
  """
172
188
  Args:
@@ -272,9 +288,17 @@ class _RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
272
288
  return self.vae.sample_size
273
289
 
274
290
 
275
- class RBLNStableDiffusionXLControlNetPipelineConfig(_RBLNStableDiffusionXLControlNetPipelineBaseConfig):
291
+ class RBLNStableDiffusionXLControlNetPipelineConfig(RBLNStableDiffusionXLControlNetPipelineBaseConfig):
292
+ """
293
+ Configuration for Stable Diffusion XL ControlNet pipeline.
294
+ """
295
+
276
296
  _vae_uses_encoder = False
277
297
 
278
298
 
279
- class RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig(_RBLNStableDiffusionXLControlNetPipelineBaseConfig):
299
+ class RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig(RBLNStableDiffusionXLControlNetPipelineBaseConfig):
300
+ """
301
+ Configuration for Stable Diffusion XL ControlNet image-to-image pipeline.
302
+ """
303
+
280
304
  _vae_uses_encoder = True
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
  from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNCLIPVisionModelWithProjectionConfig
@@ -20,7 +20,9 @@ from ..models import RBLNUNet2DConditionModelConfig, RBLNVQModelConfig
20
20
  from ..models.configuration_prior_transformer import RBLNPriorTransformerConfig
21
21
 
22
22
 
23
- class _RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
23
+ class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
24
+ """Base configuration class for Kandinsky V2.2 decoder pipelines."""
25
+
24
26
  submodules = ["unet", "movq"]
25
27
  _movq_uses_encoder = False
26
28
 
@@ -37,7 +39,7 @@ class _RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
37
39
  img_width: Optional[int] = None,
38
40
  height: Optional[int] = None,
39
41
  width: Optional[int] = None,
40
- **kwargs,
42
+ **kwargs: Dict[str, Any],
41
43
  ):
42
44
  """
43
45
  Args:
@@ -117,19 +119,27 @@ class _RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
117
119
  return self.movq.sample_size
118
120
 
119
121
 
120
- class RBLNKandinskyV22PipelineConfig(_RBLNKandinskyV22PipelineBaseConfig):
122
+ class RBLNKandinskyV22PipelineConfig(RBLNKandinskyV22PipelineBaseConfig):
123
+ """Configuration class for the Kandinsky V2.2 text-to-image decoder pipeline."""
124
+
121
125
  _movq_uses_encoder = False
122
126
 
123
127
 
124
- class RBLNKandinskyV22Img2ImgPipelineConfig(_RBLNKandinskyV22PipelineBaseConfig):
128
+ class RBLNKandinskyV22Img2ImgPipelineConfig(RBLNKandinskyV22PipelineBaseConfig):
129
+ """Configuration class for the Kandinsky V2.2 image-to-image decoder pipeline."""
130
+
125
131
  _movq_uses_encoder = True
126
132
 
127
133
 
128
- class RBLNKandinskyV22InpaintPipelineConfig(_RBLNKandinskyV22PipelineBaseConfig):
134
+ class RBLNKandinskyV22InpaintPipelineConfig(RBLNKandinskyV22PipelineBaseConfig):
135
+ """Configuration class for the Kandinsky V2.2 inpainting decoder pipeline."""
136
+
129
137
  _movq_uses_encoder = True
130
138
 
131
139
 
132
140
  class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
141
+ """Configuration class for the Kandinsky V2.2 Prior pipeline."""
142
+
133
143
  submodules = ["text_encoder", "image_encoder", "prior"]
134
144
 
135
145
  def __init__(
@@ -140,7 +150,7 @@ class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
140
150
  *,
141
151
  batch_size: Optional[int] = None,
142
152
  guidance_scale: Optional[float] = None,
143
- **kwargs,
153
+ **kwargs: Dict[str, Any],
144
154
  ):
145
155
  """
146
156
  Initialize a configuration for Kandinsky 2.2 prior pipeline optimized for RBLN NPU.
@@ -194,7 +204,9 @@ class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
194
204
  return self.image_encoder.image_size
195
205
 
196
206
 
197
- class _RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
207
+ class RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
208
+ """Base configuration class for Kandinsky V2.2 combined pipelines."""
209
+
198
210
  submodules = ["prior_pipe", "decoder_pipe"]
199
211
  _decoder_pipe_cls = RBLNKandinskyV22PipelineConfig
200
212
 
@@ -216,7 +228,7 @@ class _RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
216
228
  prior_text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
217
229
  unet: Optional[RBLNUNet2DConditionModelConfig] = None,
218
230
  movq: Optional[RBLNVQModelConfig] = None,
219
- **kwargs,
231
+ **kwargs: Dict[str, Any],
220
232
  ):
221
233
  """
222
234
  Initialize a configuration for combined Kandinsky 2.2 pipelines optimized for RBLN NPU.
@@ -325,13 +337,19 @@ class _RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
325
337
  return self.decoder_pipe.movq
326
338
 
327
339
 
328
- class RBLNKandinskyV22CombinedPipelineConfig(_RBLNKandinskyV22CombinedPipelineBaseConfig):
340
+ class RBLNKandinskyV22CombinedPipelineConfig(RBLNKandinskyV22CombinedPipelineBaseConfig):
341
+ """Configuration class for the Kandinsky V2.2 combined text-to-image pipeline."""
342
+
329
343
  _decoder_pipe_cls = RBLNKandinskyV22PipelineConfig
330
344
 
331
345
 
332
- class RBLNKandinskyV22InpaintCombinedPipelineConfig(_RBLNKandinskyV22CombinedPipelineBaseConfig):
346
+ class RBLNKandinskyV22InpaintCombinedPipelineConfig(RBLNKandinskyV22CombinedPipelineBaseConfig):
347
+ """Configuration class for the Kandinsky V2.2 combined inpainting pipeline."""
348
+
333
349
  _decoder_pipe_cls = RBLNKandinskyV22InpaintPipelineConfig
334
350
 
335
351
 
336
- class RBLNKandinskyV22Img2ImgCombinedPipelineConfig(_RBLNKandinskyV22CombinedPipelineBaseConfig):
352
+ class RBLNKandinskyV22Img2ImgCombinedPipelineConfig(RBLNKandinskyV22CombinedPipelineBaseConfig):
353
+ """Configuration class for the Kandinsky V2.2 combined image-to-image pipeline."""
354
+
337
355
  _decoder_pipe_cls = RBLNKandinskyV22Img2ImgPipelineConfig
@@ -12,14 +12,18 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Optional, Tuple
15
+ from typing import Any, Dict, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
  from ....transformers import RBLNCLIPTextModelConfig
19
19
  from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
20
20
 
21
21
 
22
- class _RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
22
+ class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
23
+ """
24
+ Base configuration for Stable Diffusion pipelines.
25
+ """
26
+
23
27
  submodules = ["text_encoder", "unet", "vae"]
24
28
  _vae_uses_encoder = False
25
29
 
@@ -37,7 +41,7 @@ class _RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
37
41
  sample_size: Optional[Tuple[int, int]] = None,
38
42
  image_size: Optional[Tuple[int, int]] = None,
39
43
  guidance_scale: Optional[float] = None,
40
- **kwargs,
44
+ **kwargs: Dict[str, Any],
41
45
  ):
42
46
  """
43
47
  Args:
@@ -128,13 +132,25 @@ class _RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
128
132
  return self.vae.sample_size
129
133
 
130
134
 
131
- class RBLNStableDiffusionPipelineConfig(_RBLNStableDiffusionPipelineBaseConfig):
135
+ class RBLNStableDiffusionPipelineConfig(RBLNStableDiffusionPipelineBaseConfig):
136
+ """
137
+ Configuration for Stable Diffusion pipeline.
138
+ """
139
+
132
140
  _vae_uses_encoder = False
133
141
 
134
142
 
135
- class RBLNStableDiffusionImg2ImgPipelineConfig(_RBLNStableDiffusionPipelineBaseConfig):
143
+ class RBLNStableDiffusionImg2ImgPipelineConfig(RBLNStableDiffusionPipelineBaseConfig):
144
+ """
145
+ Configuration for Stable Diffusion image-to-image pipeline.
146
+ """
147
+
136
148
  _vae_uses_encoder = True
137
149
 
138
150
 
139
- class RBLNStableDiffusionInpaintPipelineConfig(_RBLNStableDiffusionPipelineBaseConfig):
151
+ class RBLNStableDiffusionInpaintPipelineConfig(RBLNStableDiffusionPipelineBaseConfig):
152
+ """
153
+ Configuration for Stable Diffusion inpainting pipeline.
154
+ """
155
+
140
156
  _vae_uses_encoder = True