optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/configuration_utils.py

@@ -21,8 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runt

 import numpy as np
 import torch
+from packaging.version import Version

 from .__version__ import __version__
+from .utils.depreacate_utils import warn_deprecated_npu
 from .utils.logging import get_logger
 from .utils.runtime_utils import ContextRblnConfig

@@ -31,7 +33,6 @@ logger = get_logger(__name__)


 DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
-DEFAULT_MOD_NAME = "default"
 TypeInputInfo = List[Tuple[str, Tuple[int], str]]


@@ -39,6 +40,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
 class RBLNSerializableConfigProtocol(Protocol):
     def _prepare_for_serialization(self) -> Dict[str, Any]: ...

+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
+

 @dataclass
 class RBLNCompileConfig:
@@ -47,17 +51,13 @@

     Attributes:
         compiled_model_name (str): Name of the compiled model.
-        mod_name (str): Name of the RBLN module.
         input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
-        fusion (Optional[bool]): Whether to use fusion optimization.
         npu (Optional[str]): NPU configuration.
         tensor_parallel_size (Optional[int]): Size for tensor parallelism.
     """

     compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
-    mod_name: str = DEFAULT_MOD_NAME
     input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
-    fusion: Optional[bool] = None
     npu: Optional[str] = None
     tensor_parallel_size: Optional[int] = None

@@ -111,9 +111,7 @@

     def update(self, kwargs: Dict[str, Any]):
         self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
-        self.mod_name = kwargs.get("mod_name", self.mod_name)
         self.input_info = kwargs.get("input_info", self.input_info)
-        self.fusion = kwargs.get("fusion", self.fusion)
         self.npu = kwargs.get("npu", self.npu)
         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
         return self
@@ -147,7 +145,7 @@
         return asdict(self)


-RUNTIME_KEYWORDS = ["create_runtimes", "
+RUNTIME_KEYWORDS = ["create_runtimes", "device", "device_map", "activate_profiler", "timeout"]
 CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}


@@ -183,6 +181,15 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:


 class RBLNAutoConfig:
+    """
+    Resolver and factory for RBLN model configurations.
+
+    This class selects the concrete `RBLNModelConfig` subclass, validates the
+    provided data, and returns a frozen configuration object that serves as the
+    single source of truth during export and load. It does not define the schema
+    or control model behavior.
+    """
+
     def __new__(cls, **kwargs):
         cls_name = kwargs.get("cls_name")
         if cls_name is None:
@@ -192,6 +199,33 @@

     @staticmethod
     def load_from_dict(config_dict: Dict[str, Any]) -> "RBLNModelConfig":
+        """
+        Build a `RBLNModelConfig` from a plain dictionary.
+
+        The dictionary must contain `cls_name`, which identifies the concrete
+        configuration class to instantiate. All other keys are forwarded to the
+        target class initializer. This method does not mutate `config_dict`.
+
+        Args:
+            config_dict: Mapping typically created by `json.load` or `yaml.safe_load`.
+                For example, the parsed contents of `rbln_config.json`.
+
+        Returns:
+            RBLNModelConfig: A configuration instance. The specific subclass is
+                selected by `config_dict["cls_name"]`.
+
+        Raises:
+            ValueError: If `cls_name` is missing.
+            Exception: Any error raised by the target config class during init.
+
+        Examples:
+            >>> data = {
+            ...     "cls_name": "RBLNLlamaForCausalLMConfig",
+            ...     "create_runtimes": False,
+            ...     "tensor_parallel_size": 4
+            ... }
+            >>> cfg = RBLNAutoConfig.load_from_dict(data)
+        """
         cls_name = config_dict.get("cls_name")
         if cls_name is None:
             raise ValueError("`cls_name` is required.")
@@ -204,7 +238,8 @@
         Register a new configuration for this class.

         Args:
-            config (
+            config (RBLNModelConfig): The config to register.
+            exist_ok (bool): Whether to allow registering an already registered model.
         """
         if not issubclass(config, RBLNModelConfig):
             raise ValueError("`config` must be a subclass of RBLNModelConfig.")
@@ -246,9 +281,6 @@
             if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
         }

-        if len(rbln_kwargs) > 0:
-            raise ValueError(f"Cannot set the following arguments: {list(rbln_kwargs.keys())}")
-
         # Process submodule's rbln_config
         for submodule in cls.submodules:
             if submodule not in config_file:
@@ -263,6 +295,16 @@

         config_file.update(rbln_runtime_kwargs)

+        rbln_config = cls(**config_file)
+
+        if len(rbln_kwargs) > 0:
+            for key, value in rbln_kwargs.items():
+                if getattr(rbln_config, key) != value:
+                    raise ValueError(
+                        f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
+                        f"Since the value is already set to {getattr(rbln_config, key)}"
+                    )
+
         if return_unused_kwargs:
             return cls(**config_file), kwargs
         else:
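The two hunks above change how extra `rbln_*` keys are handled when a saved configuration is loaded: instead of being rejected up front, the config is built first and a key now raises only when its value conflicts with the stored one. A minimal sketch of the resulting behavior, assuming `RBLNAutoConfig.load` is the public entry point for reading a compiled model's `rbln_config.json` (the directory name is illustrative):

from optimum.rbln import RBLNAutoConfig

# A non-runtime rbln_ key whose value matches what is stored is now accepted:
cfg = RBLNAutoConfig.load("./llama-rbln-compiled", rbln_tensor_parallel_size=4)

# A conflicting value still fails, e.g.
#   ValueError: Cannot set the following arguments: ['tensor_parallel_size'] ...
# cfg = RBLNAutoConfig.load("./llama-rbln-compiled", rbln_tensor_parallel_size=8)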
@@ -273,6 +315,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     """Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.

     This class provides functionality for:
+
     1. Managing compilation configurations for RBLN devices
     2. Configuring runtime behavior such as device placement
     3. Handling nested configuration objects for complex model architectures
@@ -474,10 +517,10 @@
     non_save_attributes = [
         "_frozen",
         "_runtime_options",
+        "torch_dtype",
         "npu",
         "tensor_parallel_size",
         "create_runtimes",
-        "optimize_host_memory",
         "device",
         "device_map",
         "activate_profiler",
@@ -486,18 +529,18 @@
     submodules: List[str] = []
     subclass_non_save_attributes = []

-    def 
+    def initialize_submodule_config(
         self,
-        submodule_config_cls: Type["RBLNModelConfig"],
         submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
-
+        force_kwargs: bool = False,
+        **kwargs: Any,
     ) -> "RBLNModelConfig":
-        # Initialize a submodule config from a dict or a RBLNModelConfig.
-        # kwargs is specified from the predecessor config.
-
         if submodule_config is None:
             submodule_config = {}

+        if isinstance(submodule_config, RBLNModelConfig):
+            return submodule_config
+
         if isinstance(submodule_config, dict):
             from_predecessor = self._runtime_options.copy()
             from_predecessor.update(
@@ -511,13 +554,60 @@

         init_kwargs = from_predecessor
         init_kwargs.update(submodule_config)
-        submodule_config = submodule_config_cls(**init_kwargs)

-
+        if force_kwargs:
+            for key, value in kwargs.items():
+                if key in init_kwargs:
+                    if init_kwargs[key] != value:
+                        raise ValueError(
+                            f"Parameter conflict for '{key}': submodule_config has {init_kwargs[key]}, "
+                            f"but kwargs has {value}. Using kwargs value: {value}"
+                        )
+                init_kwargs[key] = value
+
+        if "cls_name" in init_kwargs:
+            config_cls = get_rbln_config_class(init_kwargs["cls_name"])
+        else:
+            return init_kwargs
+
+        submodule_config = config_cls(**init_kwargs)
+
+
         if not isinstance(submodule_config, RBLNModelConfig):
             raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")
         return submodule_config

+    def filter_parameters(self, config_cls: Type["RBLNModelConfig"], parameters: Dict[str, Any]) -> Dict[str, Any]:
+        import importlib
+
+        model_cls_name = config_cls.__name__.replace("Config", "")
+        modeling_module_name = config_cls.__module__.replace("configuration_", "modeling_")
+
+        model_cls = None
+        try:
+            modeling_module = importlib.import_module(modeling_module_name)
+            if hasattr(modeling_module, model_cls_name):
+                model_cls = getattr(modeling_module, model_cls_name)
+        except ImportError:
+            logger.debug(f"Could not import modeling module: {modeling_module_name}")
+
+        filtered_out_params = set()
+
+        if model_cls is not None:
+            if not getattr(model_cls, "_tp_support", False):
+                filtered_out_params.add("tensor_parallel_size")
+
+        filtered_params = {}
+        for key, value in parameters.items():
+            if key in filtered_out_params:
+                logger.debug(
+                    f"Parameter '{key}' filtered out for {config_cls.__name__} (not supported by model flags)."
+                )
+            else:
+                filtered_params[key] = value
+
+        return filtered_params
+
     def __setattr__(self, key, value):
         if (
             key != "_attributes_map"
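The reworked `initialize_submodule_config` above no longer takes a `submodule_config_cls` argument: a ready `RBLNModelConfig` is returned as-is, and a dict is resolved through its `cls_name` entry via `get_rbln_config_class` after inheriting the parent's runtime options. A sketch under assumed class names (`RBLNStableDiffusionPipelineConfig` and `RBLNCLIPTextModelConfig` do not appear in this diff); only the control flow comes from the hunk:

from optimum.rbln import RBLNCLIPTextModelConfig, RBLNStableDiffusionPipelineConfig

pipe_cfg = RBLNStableDiffusionPipelineConfig()  # assumed parent config with submodules

# 1. An already-built RBLNModelConfig is returned unchanged (new early return):
te_cfg = RBLNCLIPTextModelConfig(batch_size=1)
assert pipe_cfg.initialize_submodule_config(submodule_config=te_cfg) is te_cfg

# 2. A dict inherits the parent's runtime options and is resolved through its
#    "cls_name"; without a "cls_name" the merged kwargs dict itself is returned:
resolved = pipe_cfg.initialize_submodule_config(
    submodule_config={"cls_name": "RBLNCLIPTextModelConfig", "batch_size": 1}
)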
@@ -556,7 +646,6 @@
         self,
         cls_name: Optional[str] = None,
         create_runtimes: Optional[bool] = None,
-        optimize_host_memory: Optional[bool] = None,
         device: Optional[Union[int, List[int]]] = None,
         device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
         activate_profiler: Optional[bool] = None,
@@ -564,8 +653,11 @@
         tensor_parallel_size: Optional[int] = None,
         timeout: Optional[int] = None,
         optimum_rbln_version: Optional[str] = None,
+        _torch_dtype: Optional[str] = None,
         _compile_cfgs: List[RBLNCompileConfig] = [],
-
+        *,
+        optimize_host_memory: Optional[bool] = None,
+        **kwargs: Any,
     ):
         """
         Initialize a RBLN model configuration with runtime options and compile configurations.
@@ -573,7 +665,6 @@
         Args:
             cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
             create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True.
-            optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
             device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
             device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
             activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
@@ -581,8 +672,9 @@
             tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
             timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
             optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
+            _torch_dtype (Optional[str]): The data type to use for the model.
             _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
-
+            kwargs: Additional keyword arguments.

         Raises:
             ValueError: If unexpected keyword arguments are provided.
@@ -598,16 +690,19 @@

         self._runtime_options = {}
         self._runtime_options["create_runtimes"] = create_runtimes
-        self._runtime_options["optimize_host_memory"] = optimize_host_memory
         self._runtime_options["device"] = device
         self._runtime_options["device_map"] = device_map
         self._runtime_options["activate_profiler"] = activate_profiler
         self._runtime_options["timeout"] = timeout

+        if optimize_host_memory is not None:
+            logger.warning("`optimize_host_memory` is deprecated and will be removed in future versions.")
+
         # Automatically pass npu, tensor_parallel_size to compile_cfgs
         self.npu = npu
         self.tensor_parallel_size = tensor_parallel_size

+        self._torch_dtype = _torch_dtype or "float32"
         self.optimum_rbln_version = optimum_rbln_version
         if self.optimum_rbln_version is None:
             self.optimum_rbln_version = __version__
@@ -620,8 +715,34 @@
         self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])

         if len(kwargs) > 0:
+            if optimum_rbln_version is not None:  # loaded from file
+                if Version(__version__) < Version(optimum_rbln_version):
+                    diff = "newer"
+                elif Version(__version__) > Version(optimum_rbln_version):
+                    diff = "older"
+                else:
+                    diff = None
+                if diff is not None:
+                    raise ValueError(
+                        f"Unexpected arguments: {kwargs.keys()}\n"
+                        f"Maybe you are trying to load a model compiled with {diff} version of optimum-rbln. "
+                        "It is recommended to use the same version to compile and load the model.\n"
+                        f"Current version: {__version__}, Loaded version: {optimum_rbln_version}"
+                    )
+
             raise ValueError(f"Unexpected arguments: {kwargs.keys()}")

+    @property
+    def torch_dtype(self):
+        return getattr(torch, self._torch_dtype)
+
+    @torch_dtype.setter
+    def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
+        if isinstance(torch_dtype, torch.dtype):
+            torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+
+        self._torch_dtype = torch_dtype
+
     @property
     def rbln_model_cls_name(self) -> str:
         return self.__class__.__name__[:-6]
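The new `torch_dtype` property above stores the dtype as a string (`_torch_dtype`, default "float32") and resolves it with `getattr(torch, ...)` on access; assigning a `torch.dtype` is normalized back to a string through `RBLNCompileConfig.normalize_dtype`. A minimal sketch, reusing `RBLNLlamaForCausalLMConfig` from the diff's own doctest:

import torch

from optimum.rbln import RBLNLlamaForCausalLMConfig

cfg = RBLNLlamaForCausalLMConfig(tensor_parallel_size=4)
print(cfg.torch_dtype)            # torch.float32 -- _torch_dtype defaults to "float32"

cfg.torch_dtype = torch.bfloat16  # normalized back to the string "bfloat16"
print(cfg.torch_dtype)            # torch.bfloat16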
@@ -675,6 +796,9 @@
             compile_cfg.npu = self.npu
             compile_cfg.tensor_parallel_size = self.tensor_parallel_size

+        target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
+        warn_deprecated_npu(target_npu)
+
     def freeze(self):
         if self._frozen:
             raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
@@ -713,13 +837,13 @@
             json.dump(serializable_data, jsonf, indent=2)

     @classmethod
-    def load(cls, path: str, **kwargs:
+    def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
         """
         Load a RBLNModelConfig from a path.

         Args:
             path (str): Path to the RBLNModelConfig file or directory containing the config file.
-
+            kwargs: Additional keyword arguments to override configuration values.
                 Keys starting with 'rbln_' will have the prefix removed and be used
                 to update the configuration.

@@ -746,7 +870,7 @@
     def initialize_from_kwargs(
         cls: Type["RBLNModelConfig"],
         rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
-        **kwargs:
+        **kwargs: Any,
     ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
         # Initialize RBLNModelConfig from kwargs.
         kwargs_keys = list(kwargs.keys())
@@ -791,19 +915,6 @@
     def create_runtimes(self, create_runtimes: bool):
         self._runtime_options["create_runtimes"] = create_runtimes

-    @property
-    def optimize_host_memory(self):
-        context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
-        if context is not None:
-            return context
-        elif self._runtime_options["optimize_host_memory"] is None:
-            return True
-        return self._runtime_options["optimize_host_memory"]
-
-    @optimize_host_memory.setter
-    def optimize_host_memory(self, optimize_host_memory: bool):
-        self._runtime_options["optimize_host_memory"] = optimize_host_memory
-
     @property
     def device(self):
         context = ContextRblnConfig.get_current_context()["device"]
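With the property pair above removed, `optimize_host_memory` survives only as a keyword-only `__init__` parameter that is ignored apart from a deprecation warning (see the earlier `__init__` hunk). A sketch of the observable behavior, reusing the same illustrative config class:

from optimum.rbln import RBLNLlamaForCausalLMConfig

# Still accepted for backward compatibility, but now a no-op that logs:
# "`optimize_host_memory` is deprecated and will be removed in future versions."
cfg = RBLNLlamaForCausalLMConfig(optimize_host_memory=True)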
optimum/rbln/diffusers/__init__.py

@@ -59,6 +59,9 @@ _import_structure = {
         "RBLNVQModelConfig",
     ],
     "pipelines": [
+        "RBLNAutoPipelineForImage2Image",
+        "RBLNAutoPipelineForInpainting",
+        "RBLNAutoPipelineForText2Image",
         "RBLNCosmosTextToWorldPipeline",
         "RBLNCosmosVideoToWorldPipeline",
         "RBLNCosmosSafetyChecker",
@@ -135,6 +138,7 @@ if TYPE_CHECKING:
     from .modeling_diffusers import RBLNDiffusionMixin
     from .models import (
         RBLNAutoencoderKL,
+        RBLNAutoencoderKLCosmos,
         RBLNControlNetModel,
         RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
@@ -143,6 +147,9 @@
         RBLNVQModel,
     )
     from .pipelines import (
+        RBLNAutoPipelineForImage2Image,
+        RBLNAutoPipelineForInpainting,
+        RBLNAutoPipelineForText2Image,
         RBLNCosmosSafetyChecker,
         RBLNCosmosTextToWorldPipeline,
         RBLNCosmosVideoToWorldPipeline,
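These exports back the new `auto_pipeline.py` module (+307 lines in the file list), which mirrors diffusers' `AutoPipelineFor*` entry points. A usage sketch; the top-level re-export, the checkpoint id, and the `export=True` compile flag follow optimum-rbln conventions but are assumptions here:

from optimum.rbln import RBLNAutoPipelineForText2Image

# Resolve the concrete RBLN pipeline class from the checkpoint's config and
# compile it for the NPU on first load.
pipe = RBLNAutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    export=True,
)
image = pipe("A coastal village at dawn").images[0]
image.save("village.png")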
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNModelConfig

@@ -33,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
         vae_scale_factor: Optional[float] = None,  # TODO: rename to scaling_factor
         in_channels: Optional[int] = None,
         latent_channels: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -46,7 +46,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
                 Determines how much smaller the latent representations are compared to the original images.
             in_channels (Optional[int]): Number of input channels for the model.
             latent_channels (Optional[int]): Number of channels in the latent space.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
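The remaining diffusers model configs below receive the same mechanical change: `**kwargs` gains an `Any` annotation and the docstrings note the pass-through to `RBLNModelConfig`. One hedged construction sketch stands in for all of them; the values are illustrative:

from optimum.rbln import RBLNAutoencoderKLConfig

vae_cfg = RBLNAutoencoderKLConfig(
    batch_size=1,           # ValueError if not a positive integer
    in_channels=3,
    latent_channels=4,
    create_runtimes=False,  # a kwarg forwarded to the parent RBLNModelConfig
)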
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py

@@ -52,7 +52,7 @@ class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
                 Determines how much smaller the latent representations are compared to the original videos.
             use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
                 If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_controlnet.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNModelConfig

@@ -29,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
         unet_sample_size: Optional[Tuple[int, int]] = None,
         vae_sample_size: Optional[Tuple[int, int]] = None,
         text_model_hidden_size: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -42,7 +42,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
                 of the VAE input/output images.
             text_model_hidden_size (Optional[int]): Hidden size of the text encoder model used
                 for conditioning.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig

@@ -22,7 +22,7 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
     Configuration class for RBLN Prior Transformer models.

     This class inherits from RBLNModelConfig and provides specific configuration options
-    for
+    for Transformer models used in diffusion models like Kandinsky V2.2.
     """

     subclass_non_save_attributes = ["_batch_size_is_specified"]
@@ -32,14 +32,14 @@
         batch_size: Optional[int] = None,
         embedding_dim: Optional[int] = None,
         num_embeddings: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             embedding_dim (Optional[int]): Dimension of the embedding vectors in the model.
             num_embeddings (Optional[int]): Number of discrete embeddings in the codebook.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py

@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig


 class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
-    """
+    """
+    Configuration class for RBLN Cosmos Transformer models.
+
+    This class inherits from RBLNModelConfig and provides specific configuration options
+    for Transformer models used in diffusion models like Cosmos.
+    """

     def __init__(
         self,
@@ -33,7 +38,7 @@ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
         num_latent_frames: Optional[int] = None,
         latent_height: Optional[int] = None,
         latent_width: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -47,7 +52,7 @@ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
             num_channels_latents (Optional[int]): The number of channels in latent space.
             latent_height (Optional[int]): The height in pixels in latent space.
             latent_width (Optional[int]): The width in pixels in latent space.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py

@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional, Tuple, Union

 from ....configuration_utils import RBLNModelConfig


 class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
-    """
+    """
+    Configuration class for RBLN Stable Diffusion 3 Transformer models.
+
+    This class inherits from RBLNModelConfig and provides specific configuration options
+    for Transformer models used in diffusion models like Stable Diffusion 3.
+    """

     subclass_non_save_attributes = ["_batch_size_is_specified"]

@@ -27,7 +32,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         prompt_embed_length: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -36,7 +41,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
                 of the generated samples. If an integer is provided, it's used for both height and width.
             prompt_embed_length (Optional[int]): The length of the embedded prompt vectors that
                 will be used to condition the transformer model.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNModelConfig

@@ -38,7 +38,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
         in_features: Optional[int] = None,
         text_model_hidden_size: Optional[int] = None,
         image_model_hidden_size: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -52,7 +52,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
             in_features (Optional[int]): Number of input features for the model.
             text_model_hidden_size (Optional[int]): Hidden size of the text encoder model.
             image_model_hidden_size (Optional[int]): Hidden size of the image encoder model.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
optimum/rbln/diffusers/configurations/models/configuration_vq_model.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNModelConfig

@@ -33,7 +33,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
         vqmodel_scale_factor: Optional[float] = None,  # TODO: rename to scaling_factor
         in_channels: Optional[int] = None,
         latent_channels: Optional[int] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
@@ -46,7 +46,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
                 Determines the downsampling ratio between original images and latent representations.
             in_channels (Optional[int]): Number of input channels for the model.
             latent_channels (Optional[int]): Number of channels in the latent space.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.