optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/configuration_utils.py

@@ -21,8 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runt

 import numpy as np
 import torch
+from packaging.version import Version

 from .__version__ import __version__
+from .utils.deprecation import warn_deprecated_npu
 from .utils.logging import get_logger
 from .utils.runtime_utils import ContextRblnConfig

@@ -31,7 +33,6 @@ logger = get_logger(__name__)


 DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
-DEFAULT_MOD_NAME = "default"
 TypeInputInfo = List[Tuple[str, Tuple[int], str]]


@@ -39,6 +40,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
 class RBLNSerializableConfigProtocol(Protocol):
     def _prepare_for_serialization(self) -> Dict[str, Any]: ...

+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
+

 @dataclass
 class RBLNCompileConfig:
@@ -47,17 +51,13 @@ class RBLNCompileConfig:

     Attributes:
         compiled_model_name (str): Name of the compiled model.
-        mod_name (str): Name of the RBLN module.
         input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
-        fusion (Optional[bool]): Whether to use fusion optimization.
         npu (Optional[str]): NPU configuration.
         tensor_parallel_size (Optional[int]): Size for tensor parallelism.
     """

     compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
-    mod_name: str = DEFAULT_MOD_NAME
     input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
-    fusion: Optional[bool] = None
     npu: Optional[str] = None
     tensor_parallel_size: Optional[int] = None

@@ -111,9 +111,7 @@ class RBLNCompileConfig:

     def update(self, kwargs: Dict[str, Any]):
         self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
-        self.mod_name = kwargs.get("mod_name", self.mod_name)
         self.input_info = kwargs.get("input_info", self.input_info)
-        self.fusion = kwargs.get("fusion", self.fusion)
         self.npu = kwargs.get("npu", self.npu)
         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
         return self

@@ -147,7 +145,7 @@ class RBLNCompileConfig:
         return asdict(self)


-RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler"]
+RUNTIME_KEYWORDS = ["create_runtimes", "device", "device_map", "activate_profiler", "timeout"]
 CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}


@@ -183,6 +181,15 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:


 class RBLNAutoConfig:
+    """
+    Resolver and factory for RBLN model configurations.
+
+    This class selects the concrete `RBLNModelConfig` subclass, validates the
+    provided data, and returns a frozen configuration object that serves as the
+    single source of truth during export and load. It does not define the schema
+    or control model behavior.
+    """
+
     def __new__(cls, **kwargs):
         cls_name = kwargs.get("cls_name")
         if cls_name is None:

@@ -192,6 +199,33 @@ class RBLNAutoConfig:

     @staticmethod
     def load_from_dict(config_dict: Dict[str, Any]) -> "RBLNModelConfig":
+        """
+        Build a `RBLNModelConfig` from a plain dictionary.
+
+        The dictionary must contain `cls_name`, which identifies the concrete
+        configuration class to instantiate. All other keys are forwarded to the
+        target class initializer. This method does not mutate `config_dict`.
+
+        Args:
+            config_dict: Mapping typically created by `json.load` or `yaml.safe_load`.
+                For example, the parsed contents of `rbln_config.json`.
+
+        Returns:
+            RBLNModelConfig: A configuration instance. The specific subclass is
+                selected by `config_dict["cls_name"]`.
+
+        Raises:
+            ValueError: If `cls_name` is missing.
+            Exception: Any error raised by the target config class during init.
+
+        Examples:
+            >>> data = {
+            ...     "cls_name": "RBLNLlamaForCausalLMConfig",
+            ...     "create_runtimes": False,
+            ...     "tensor_parallel_size": 4
+            ... }
+            >>> cfg = RBLNAutoConfig.load_from_dict(data)
+        """
         cls_name = config_dict.get("cls_name")
         if cls_name is None:
             raise ValueError("`cls_name` is required.")
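
The docstring added above documents the dictionary round-trip. As a hedged sketch of the same flow (the file name follows the `rbln_config.json` mentioned in the docstring; values come from the docstring example):

    import json

    from optimum.rbln.configuration_utils import RBLNAutoConfig

    # Rebuild the concrete config class named by "cls_name" from a saved JSON file.
    with open("rbln_config.json") as f:
        cfg = RBLNAutoConfig.load_from_dict(json.load(f))

    print(type(cfg).__name__)  # e.g. "RBLNLlamaForCausalLMConfig"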
@@ -204,7 +238,8 @@ class RBLNAutoConfig:
         Register a new configuration for this class.

         Args:
-            config ([`RBLNModelConfig`]): The config to register.
+            config (RBLNModelConfig): The config to register.
+            exist_ok (bool): Whether to allow registering an already registered model.
         """
         if not issubclass(config, RBLNModelConfig):
             raise ValueError("`config` must be a subclass of RBLNModelConfig.")

@@ -246,9 +281,6 @@ class RBLNAutoConfig:
             if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
         }

-        if len(rbln_kwargs) > 0:
-            raise ValueError(f"Cannot set the following arguments: {list(rbln_kwargs.keys())}")
-
         # Process submodule's rbln_config
         for submodule in cls.submodules:
             if submodule not in config_file:

@@ -263,6 +295,16 @@ class RBLNAutoConfig:

         config_file.update(rbln_runtime_kwargs)

+        rbln_config = cls(**config_file)
+
+        if len(rbln_kwargs) > 0:
+            for key, value in rbln_kwargs.items():
+                if getattr(rbln_config, key) != value:
+                    raise ValueError(
+                        f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
+                        f"Since the value is already set to {getattr(rbln_config, key)}"
+                    )
+
         if return_unused_kwargs:
             return cls(**config_file), kwargs
         else:
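
This loader hunk relaxes the old behavior: instead of rejecting every `rbln_`-prefixed, non-runtime keyword outright, it builds the stored config first and raises only when a passed value disagrees with the saved one. A hedged sketch of the resulting behavior (directory, config class, and stored values are illustrative):

    from optimum.rbln import RBLNLlamaForCausalLMConfig  # illustrative config class

    # Accepted: the override matches the value saved at compile time.
    cfg = RBLNLlamaForCausalLMConfig.load("./compiled-llama", rbln_tensor_parallel_size=4)

    # Raises ValueError: the saved config already holds tensor_parallel_size=4.
    cfg = RBLNLlamaForCausalLMConfig.load("./compiled-llama", rbln_tensor_parallel_size=8)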
@@ -273,6 +315,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     """Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.

     This class provides functionality for:
+
     1. Managing compilation configurations for RBLN devices
     2. Configuring runtime behavior such as device placement
     3. Handling nested configuration objects for complex model architectures

@@ -474,29 +517,31 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     non_save_attributes = [
         "_frozen",
         "_runtime_options",
+        "torch_dtype",
         "npu",
         "tensor_parallel_size",
         "create_runtimes",
-        "optimize_host_memory",
         "device",
         "device_map",
         "activate_profiler",
+        "timeout",
     ]
     submodules: List[str] = []
     subclass_non_save_attributes = []
+    _allow_no_compile_cfgs = False

-    def init_submodule_config(
+    def initialize_submodule_config(
         self,
-        submodule_config_cls: Type["RBLNModelConfig"],
         submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
-        **kwargs: Dict[str, Any],
+        force_kwargs: bool = False,
+        **kwargs: Any,
     ) -> "RBLNModelConfig":
-        # Initialize a submodule config from a dict or a RBLNModelConfig.
-        # kwargs is specified from the predecessor config.
-
         if submodule_config is None:
             submodule_config = {}

+        if isinstance(submodule_config, RBLNModelConfig):
+            return submodule_config
+
         if isinstance(submodule_config, dict):
             from_predecessor = self._runtime_options.copy()
             from_predecessor.update(
@@ -510,13 +555,60 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):

             init_kwargs = from_predecessor
             init_kwargs.update(submodule_config)
-            submodule_config = submodule_config_cls(**init_kwargs)

-        if not isinstance(submodule_config, submodule_config_cls):
+            if force_kwargs:
+                for key, value in kwargs.items():
+                    if key in init_kwargs:
+                        if init_kwargs[key] != value:
+                            raise ValueError(
+                                f"Parameter conflict for '{key}': submodule_config has {init_kwargs[key]}, "
+                                f"but kwargs has {value}. Using kwargs value: {value}"
+                            )
+                    init_kwargs[key] = value
+
+            if "cls_name" in init_kwargs:
+                config_cls = get_rbln_config_class(init_kwargs["cls_name"])
+            else:
+                return init_kwargs
+
+            submodule_config = config_cls(**init_kwargs)
+
+        if not isinstance(submodule_config, RBLNModelConfig):
             raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")

         return submodule_config

+    def filter_parameters(self, config_cls: Type["RBLNModelConfig"], parameters: Dict[str, Any]) -> Dict[str, Any]:
+        import importlib
+
+        model_cls_name = config_cls.__name__.replace("Config", "")
+        modeling_module_name = config_cls.__module__.replace("configuration_", "modeling_")
+
+        model_cls = None
+        try:
+            modeling_module = importlib.import_module(modeling_module_name)
+            if hasattr(modeling_module, model_cls_name):
+                model_cls = getattr(modeling_module, model_cls_name)
+        except ImportError:
+            logger.debug(f"Could not import modeling module: {modeling_module_name}")
+
+        filtered_out_params = set()
+
+        if model_cls is not None:
+            if not getattr(model_cls, "_tp_support", False):
+                filtered_out_params.add("tensor_parallel_size")
+
+        filtered_params = {}
+        for key, value in parameters.items():
+            if key in filtered_out_params:
+                logger.debug(
+                    f"Parameter '{key}' filtered out for {config_cls.__name__} (not supported by model flags)."
+                )
+            else:
+                filtered_params[key] = value
+
+        return filtered_params
+
     def __setattr__(self, key, value):
         if (
             key != "_attributes_map"
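
As the hunk above shows, `initialize_submodule_config` now passes an existing `RBLNModelConfig` through unchanged, resolves a plain dict through its `cls_name` entry, and returns the dict as-is when no `cls_name` is present. A hedged sketch of the dict path (the parent config instance and values are purely illustrative):

    # `pipeline_cfg` stands in for any RBLNModelConfig subclass instance that owns a VAE submodule.
    sub = {"cls_name": "RBLNAutoencoderKLConfig", "batch_size": 2}
    vae_cfg = pipeline_cfg.initialize_submodule_config(submodule_config=sub)
    assert type(vae_cfg).__name__ == "RBLNAutoencoderKLConfig"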
@@ -555,15 +647,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         self,
         cls_name: Optional[str] = None,
         create_runtimes: Optional[bool] = None,
-        optimize_host_memory: Optional[bool] = None,
         device: Optional[Union[int, List[int]]] = None,
         device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
         activate_profiler: Optional[bool] = None,
         npu: Optional[str] = None,
         tensor_parallel_size: Optional[int] = None,
+        timeout: Optional[int] = None,
         optimum_rbln_version: Optional[str] = None,
+        _torch_dtype: Optional[str] = None,
         _compile_cfgs: List[RBLNCompileConfig] = [],
-        **kwargs: Dict[str, Any],
+        *,
+        optimize_host_memory: Optional[bool] = None,
+        **kwargs: Any,
     ):
         """
         Initialize a RBLN model configuration with runtime options and compile configurations.

@@ -571,15 +666,16 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         Args:
             cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
             create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True.
-            optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
             device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
             device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
             activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
             npu (Optional[str]): The NPU device name to use for compilation.
             tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
+            timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
             optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
+            _torch_dtype (Optional[str]): The data type to use for the model.
             _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
-            **kwargs: Additional keyword arguments.
+            kwargs: Additional keyword arguments.

         Raises:
             ValueError: If unexpected keyword arguments are provided.

@@ -595,15 +691,19 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):

         self._runtime_options = {}
         self._runtime_options["create_runtimes"] = create_runtimes
-        self._runtime_options["optimize_host_memory"] = optimize_host_memory
         self._runtime_options["device"] = device
         self._runtime_options["device_map"] = device_map
         self._runtime_options["activate_profiler"] = activate_profiler
+        self._runtime_options["timeout"] = timeout
+
+        if optimize_host_memory is not None:
+            logger.warning("`optimize_host_memory` is deprecated and will be removed in future versions.")

         # Automatically pass npu, tensor_parallel_size to compile_cfgs
         self.npu = npu
         self.tensor_parallel_size = tensor_parallel_size

+        self._torch_dtype = _torch_dtype or "float32"
         self.optimum_rbln_version = optimum_rbln_version
         if self.optimum_rbln_version is None:
             self.optimum_rbln_version = __version__
@@ -616,8 +716,34 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])

         if len(kwargs) > 0:
+            if optimum_rbln_version is not None:  # loaded from file
+                if Version(__version__) < Version(optimum_rbln_version):
+                    diff = "newer"
+                elif Version(__version__) > Version(optimum_rbln_version):
+                    diff = "older"
+                else:
+                    diff = None
+                if diff is not None:
+                    raise ValueError(
+                        f"Unexpected arguments: {kwargs.keys()}\n"
+                        f"Maybe you are trying to load a model compiled with {diff} version of optimum-rbln. "
+                        "It is recommended to use the same version to compile and load the model.\n"
+                        f"Current version: {__version__}, Loaded version: {optimum_rbln_version}"
+                    )
+
             raise ValueError(f"Unexpected arguments: {kwargs.keys()}")

+    @property
+    def torch_dtype(self):
+        return getattr(torch, self._torch_dtype)
+
+    @torch_dtype.setter
+    def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
+        if isinstance(torch_dtype, torch.dtype):
+            torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+
+        self._torch_dtype = torch_dtype
+
     @property
     def rbln_model_cls_name(self) -> str:
         return self.__class__.__name__[:-6]
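
The new `torch_dtype` property resolves the stored string through `getattr(torch, ...)`, and its setter also accepts a `torch.dtype`, normalizing it back to a string with `RBLNCompileConfig.normalize_dtype`. A hedged sketch (the config class is the illustrative one from the `load_from_dict` docstring; a fresh, unfrozen config is assumed):

    import torch

    from optimum.rbln import RBLNLlamaForCausalLMConfig

    cfg = RBLNLlamaForCausalLMConfig()
    print(cfg.torch_dtype)        # torch.float32, since _torch_dtype defaults to "float32"

    cfg.torch_dtype = "bfloat16"  # a plain string is stored as-is
    print(cfg.torch_dtype)        # torch.bfloat16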
@@ -671,6 +797,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             compile_cfg.npu = self.npu
             compile_cfg.tensor_parallel_size = self.tensor_parallel_size

+        target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
+        warn_deprecated_npu(target_npu)
+
     def freeze(self):
         if self._frozen:
             raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")

@@ -680,7 +809,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             or len(self._compile_cfgs) == 0
             or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
         ):
-            raise RuntimeError("`compile_cfgs` must be set before freezing.")
+            if not self._allow_no_compile_cfgs:
+                raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")

         for submodule_name in self.submodules:
             submodule_config = getattr(self, submodule_name, None)

@@ -709,13 +839,13 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             json.dump(serializable_data, jsonf, indent=2)

     @classmethod
-    def load(cls, path: str, **kwargs: Dict[str, Any]) -> "RBLNModelConfig":
+    def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
         """
         Load a RBLNModelConfig from a path.

         Args:
             path (str): Path to the RBLNModelConfig file or directory containing the config file.
-            **kwargs: Additional keyword arguments to override configuration values.
+            kwargs: Additional keyword arguments to override configuration values.
                 Keys starting with 'rbln_' will have the prefix removed and be used
                 to update the configuration.

@@ -742,7 +872,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     def initialize_from_kwargs(
         cls: Type["RBLNModelConfig"],
         rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
         # Initialize RBLNModelConfig from kwargs.
         kwargs_keys = list(kwargs.keys())

@@ -787,19 +917,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     def create_runtimes(self, create_runtimes: bool):
         self._runtime_options["create_runtimes"] = create_runtimes

-    @property
-    def optimize_host_memory(self):
-        context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
-        if context is not None:
-            return context
-        elif self._runtime_options["optimize_host_memory"] is None:
-            return True
-        return self._runtime_options["optimize_host_memory"]
-
-    @optimize_host_memory.setter
-    def optimize_host_memory(self, optimize_host_memory: bool):
-        self._runtime_options["optimize_host_memory"] = optimize_host_memory
-
     @property
     def device(self):
         context = ContextRblnConfig.get_current_context()["device"]

@@ -838,3 +955,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     @activate_profiler.setter
     def activate_profiler(self, activate_profiler: bool):
         self._runtime_options["activate_profiler"] = activate_profiler
+
+    @property
+    def timeout(self):
+        context = ContextRblnConfig.get_current_context()["timeout"]
+        if context is not None:
+            return context
+        return self._runtime_options["timeout"]
+
+    @timeout.setter
+    def timeout(self, timeout: int):
+        self._runtime_options["timeout"] = timeout
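
With `timeout` added to RUNTIME_KEYWORDS and exposed as a property above, it can be supplied like the other runtime options: set on the config object, or passed with the `rbln_` prefix at load time. A hedged sketch (the model id and values are illustrative):

    from optimum.rbln import RBLNLlamaForCausalLM

    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,
        rbln_tensor_parallel_size=4,
        rbln_timeout=120,  # runtime timeout in seconds; the docstring above states 60 when unset
    )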
optimum/rbln/diffusers/__init__.py

@@ -57,8 +57,14 @@ _import_structure = {
         "RBLNSD3Transformer2DModelConfig",
         "RBLNUNet2DConditionModelConfig",
         "RBLNVQModelConfig",
+        "RBLNUNetSpatioTemporalConditionModelConfig",
+        "RBLNStableVideoDiffusionPipelineConfig",
+        "RBLNAutoencoderKLTemporalDecoderConfig",
     ],
     "pipelines": [
+        "RBLNAutoPipelineForImage2Image",
+        "RBLNAutoPipelineForInpainting",
+        "RBLNAutoPipelineForText2Image",
         "RBLNCosmosTextToWorldPipeline",
         "RBLNCosmosVideoToWorldPipeline",
         "RBLNCosmosSafetyChecker",

@@ -83,14 +89,17 @@ _import_structure = {
         "RBLNStableDiffusion3Pipeline",
         "RBLNStableDiffusion3Img2ImgPipeline",
         "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableVideoDiffusionPipeline",
     ],
     "models": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLCosmos",
         "RBLNUNet2DConditionModel",
+        "RBLNUNetSpatioTemporalConditionModel",
         "RBLNControlNetModel",
         "RBLNCosmosTransformer3DModel",
         "RBLNSD3Transformer2DModel",
+        "RBLNAutoencoderKLTemporalDecoder",
         "RBLNPriorTransformer",
         "RBLNVQModel",
     ],

@@ -103,6 +112,7 @@ if TYPE_CHECKING:
     from .configurations import (
         RBLNAutoencoderKLConfig,
         RBLNAutoencoderKLCosmosConfig,
+        RBLNAutoencoderKLTemporalDecoderConfig,
         RBLNControlNetModelConfig,
         RBLNCosmosTextToWorldPipelineConfig,
         RBLNCosmosTransformer3DModelConfig,

@@ -129,20 +139,28 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLImg2ImgPipelineConfig,
         RBLNStableDiffusionXLInpaintPipelineConfig,
         RBLNStableDiffusionXLPipelineConfig,
+        RBLNStableVideoDiffusionPipelineConfig,
         RBLNUNet2DConditionModelConfig,
+        RBLNUNetSpatioTemporalConditionModelConfig,
         RBLNVQModelConfig,
     )
     from .modeling_diffusers import RBLNDiffusionMixin
     from .models import (
         RBLNAutoencoderKL,
+        RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLTemporalDecoder,
         RBLNControlNetModel,
         RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNUNet2DConditionModel,
+        RBLNUNetSpatioTemporalConditionModel,
         RBLNVQModel,
     )
     from .pipelines import (
+        RBLNAutoPipelineForImage2Image,
+        RBLNAutoPipelineForInpainting,
+        RBLNAutoPipelineForText2Image,
         RBLNCosmosSafetyChecker,
         RBLNCosmosTextToWorldPipeline,
         RBLNCosmosVideoToWorldPipeline,

@@ -167,6 +185,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLImg2ImgPipeline,
         RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
+        RBLNStableVideoDiffusionPipeline,
     )
 else:
     import sys
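
The export changes above add Stable Video Diffusion support and diffusers-style auto pipelines. A hedged sketch of the new auto pipeline entry point (the model id is illustrative; the class mirrors the diffusers AutoPipeline API it wraps):

    from optimum.rbln.diffusers import RBLNAutoPipelineForText2Image

    pipe = RBLNAutoPipelineForText2Image.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        export=True,
    )
    image = pipe("a photo of an astronaut riding a horse").images[0]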
optimum/rbln/diffusers/configurations/__init__.py

@@ -1,11 +1,13 @@
 from .models import (
     RBLNAutoencoderKLConfig,
     RBLNAutoencoderKLCosmosConfig,
+    RBLNAutoencoderKLTemporalDecoderConfig,
     RBLNControlNetModelConfig,
     RBLNCosmosTransformer3DModelConfig,
     RBLNPriorTransformerConfig,
     RBLNSD3Transformer2DModelConfig,
     RBLNUNet2DConditionModelConfig,
+    RBLNUNetSpatioTemporalConditionModelConfig,
     RBLNVQModelConfig,
 )
 from .pipelines import (

@@ -31,4 +33,5 @@ from .pipelines import (
     RBLNStableDiffusionXLImg2ImgPipelineConfig,
     RBLNStableDiffusionXLInpaintPipelineConfig,
     RBLNStableDiffusionXLPipelineConfig,
+    RBLNStableVideoDiffusionPipelineConfig,
 )
optimum/rbln/diffusers/configurations/models/__init__.py

@@ -1,8 +1,10 @@
 from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
 from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
+from .configuration_autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoderConfig
 from .configuration_controlnet import RBLNControlNetModelConfig
 from .configuration_prior_transformer import RBLNPriorTransformerConfig
 from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
 from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
 from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
+from .configuration_unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModelConfig
 from .configuration_vq_model import RBLNVQModelConfig
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNModelConfig

@@ -33,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
         vae_scale_factor: Optional[float] = None,  # TODO: rename to scaling_factor
         in_channels: Optional[int] = None,
         latent_channels: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:

@@ -46,7 +46,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
                 Determines how much smaller the latent representations are compared to the original images.
             in_channels (Optional[int]): Number of input channels for the model.
             latent_channels (Optional[int]): Number of channels in the latent space.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py

@@ -52,7 +52,7 @@ class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
                 Determines how much smaller the latent representations are compared to the original videos.
             use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
                 If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py (new file)

@@ -0,0 +1,67 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNAutoencoderKLTemporalDecoderConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        uses_encoder: Optional[bool] = None,
+        num_frames: Optional[int] = None,
+        decode_chunk_size: Optional[int] = None,
+        vae_scale_factor: Optional[float] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the input/output images.
+                If an integer is provided, it's used for both height and width.
+            uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
+                When False, only the decoder is used (for latent-to-image conversion).
+            num_frames (Optional[int]): The number of frames in the generated video.
+            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
+                Useful for managing memory usage during video generation.
+            vae_scale_factor (Optional[float]): The scaling factor between pixel space and latent space.
+                Determines how much smaller the latent representations are compared to the original images.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.uses_encoder = uses_encoder
+        self.num_frames = num_frames
+        self.decode_chunk_size = decode_chunk_size
+        self.vae_scale_factor = vae_scale_factor
+        self.sample_size = sample_size
+        if isinstance(sample_size, int):
+            self.sample_size = (sample_size, sample_size)
+
+    @property
+    def image_size(self):
+        return self.sample_size
+
+    @property
+    def latent_sample_size(self):
+        return (self.image_size[0] // self.vae_scale_factor, self.image_size[1] // self.vae_scale_factor)
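
A minimal sketch using the new config class added above (values are illustrative; `latent_sample_size` simply divides the spatial size by `vae_scale_factor`):

    from optimum.rbln.diffusers import RBLNAutoencoderKLTemporalDecoderConfig  # exported per the __init__ diff above

    cfg = RBLNAutoencoderKLTemporalDecoderConfig(
        batch_size=1,
        sample_size=(576, 1024),
        num_frames=25,
        decode_chunk_size=8,
        vae_scale_factor=8,
    )
    print(cfg.image_size)          # (576, 1024)
    print(cfg.latent_sample_size)  # (72, 128)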
optimum/rbln/diffusers/configurations/models/configuration_controlnet.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNModelConfig

@@ -29,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
         unet_sample_size: Optional[Tuple[int, int]] = None,
         vae_sample_size: Optional[Tuple[int, int]] = None,
         text_model_hidden_size: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:

@@ -42,7 +42,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
                 of the VAE input/output images.
             text_model_hidden_size (Optional[int]): Hidden size of the text encoder model used
                 for conditioning.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.