optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -21,8 +21,10 @@ from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runt
|
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
23
|
import torch
|
|
24
|
+
from packaging.version import Version
|
|
24
25
|
|
|
25
26
|
from .__version__ import __version__
|
|
27
|
+
from .utils.deprecation import warn_deprecated_npu
|
|
26
28
|
from .utils.logging import get_logger
|
|
27
29
|
from .utils.runtime_utils import ContextRblnConfig
|
|
28
30
|
|
|
@@ -31,7 +33,6 @@ logger = get_logger(__name__)
|
|
|
31
33
|
|
|
32
34
|
|
|
33
35
|
DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
|
|
34
|
-
DEFAULT_MOD_NAME = "default"
|
|
35
36
|
TypeInputInfo = List[Tuple[str, Tuple[int], str]]
|
|
36
37
|
|
|
37
38
|
|
|
@@ -39,6 +40,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
|
|
|
39
40
|
class RBLNSerializableConfigProtocol(Protocol):
|
|
40
41
|
def _prepare_for_serialization(self) -> Dict[str, Any]: ...
|
|
41
42
|
|
|
43
|
+
def __repr__(self) -> str:
|
|
44
|
+
return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
|
|
45
|
+
|
|
42
46
|
|
|
43
47
|
@dataclass
|
|
44
48
|
class RBLNCompileConfig:
|
|
@@ -47,17 +51,13 @@ class RBLNCompileConfig:
|
|
|
47
51
|
|
|
48
52
|
Attributes:
|
|
49
53
|
compiled_model_name (str): Name of the compiled model.
|
|
50
|
-
mod_name (str): Name of the RBLN module.
|
|
51
54
|
input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
|
|
52
|
-
fusion (Optional[bool]): Whether to use fusion optimization.
|
|
53
55
|
npu (Optional[str]): NPU configuration.
|
|
54
56
|
tensor_parallel_size (Optional[int]): Size for tensor parallelism.
|
|
55
57
|
"""
|
|
56
58
|
|
|
57
59
|
compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
|
|
58
|
-
mod_name: str = DEFAULT_MOD_NAME
|
|
59
60
|
input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
|
|
60
|
-
fusion: Optional[bool] = None
|
|
61
61
|
npu: Optional[str] = None
|
|
62
62
|
tensor_parallel_size: Optional[int] = None
|
|
63
63
|
|
|
@@ -111,9 +111,7 @@ class RBLNCompileConfig:
|
|
|
111
111
|
|
|
112
112
|
def update(self, kwargs: Dict[str, Any]):
|
|
113
113
|
self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
|
|
114
|
-
self.mod_name = kwargs.get("mod_name", self.mod_name)
|
|
115
114
|
self.input_info = kwargs.get("input_info", self.input_info)
|
|
116
|
-
self.fusion = kwargs.get("fusion", self.fusion)
|
|
117
115
|
self.npu = kwargs.get("npu", self.npu)
|
|
118
116
|
self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
|
|
119
117
|
return self
|
|
@@ -147,7 +145,7 @@ class RBLNCompileConfig:
|
|
|
147
145
|
return asdict(self)
|
|
148
146
|
|
|
149
147
|
|
|
150
|
-
RUNTIME_KEYWORDS = ["create_runtimes", "
|
|
148
|
+
RUNTIME_KEYWORDS = ["create_runtimes", "device", "device_map", "activate_profiler", "timeout"]
|
|
151
149
|
CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
|
|
152
150
|
|
|
153
151
|
|
|
@@ -183,6 +181,15 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
|
|
|
183
181
|
|
|
184
182
|
|
|
185
183
|
class RBLNAutoConfig:
|
|
184
|
+
"""
|
|
185
|
+
Resolver and factory for RBLN model configurations.
|
|
186
|
+
|
|
187
|
+
This class selects the concrete `RBLNModelConfig` subclass, validates the
|
|
188
|
+
provided data, and returns a frozen configuration object that serves as the
|
|
189
|
+
single source of truth during export and load. It does not define the schema
|
|
190
|
+
or control model behavior.
|
|
191
|
+
"""
|
|
192
|
+
|
|
186
193
|
def __new__(cls, **kwargs):
|
|
187
194
|
cls_name = kwargs.get("cls_name")
|
|
188
195
|
if cls_name is None:
|
|
@@ -192,6 +199,33 @@ class RBLNAutoConfig:
|
|
|
192
199
|
|
|
193
200
|
@staticmethod
|
|
194
201
|
def load_from_dict(config_dict: Dict[str, Any]) -> "RBLNModelConfig":
|
|
202
|
+
"""
|
|
203
|
+
Build a `RBLNModelConfig` from a plain dictionary.
|
|
204
|
+
|
|
205
|
+
The dictionary must contain `cls_name`, which identifies the concrete
|
|
206
|
+
configuration class to instantiate. All other keys are forwarded to the
|
|
207
|
+
target class initializer. This method does not mutate `config_dict`.
|
|
208
|
+
|
|
209
|
+
Args:
|
|
210
|
+
config_dict: Mapping typically created by `json.load` or `yaml.safe_load`.
|
|
211
|
+
For example, the parsed contents of `rbln_config.json`.
|
|
212
|
+
|
|
213
|
+
Returns:
|
|
214
|
+
RBLNModelConfig: A configuration instance. The specific subclass is
|
|
215
|
+
selected by `config_dict["cls_name"]`.
|
|
216
|
+
|
|
217
|
+
Raises:
|
|
218
|
+
ValueError: If `cls_name` is missing.
|
|
219
|
+
Exception: Any error raised by the target config class during init.
|
|
220
|
+
|
|
221
|
+
Examples:
|
|
222
|
+
>>> data = {
|
|
223
|
+
... "cls_name": "RBLNLlamaForCausalLMConfig",
|
|
224
|
+
... "create_runtimes": False,
|
|
225
|
+
... "tensor_parallel_size": 4
|
|
226
|
+
... }
|
|
227
|
+
>>> cfg = RBLNAutoConfig.load_from_dict(data)
|
|
228
|
+
"""
|
|
195
229
|
cls_name = config_dict.get("cls_name")
|
|
196
230
|
if cls_name is None:
|
|
197
231
|
raise ValueError("`cls_name` is required.")
|
|
@@ -204,7 +238,8 @@ class RBLNAutoConfig:
|
|
|
204
238
|
Register a new configuration for this class.
|
|
205
239
|
|
|
206
240
|
Args:
|
|
207
|
-
config (
|
|
241
|
+
config (RBLNModelConfig): The config to register.
|
|
242
|
+
exist_ok (bool): Whether to allow registering an already registered model.
|
|
208
243
|
"""
|
|
209
244
|
if not issubclass(config, RBLNModelConfig):
|
|
210
245
|
raise ValueError("`config` must be a subclass of RBLNModelConfig.")
|
|
@@ -246,9 +281,6 @@ class RBLNAutoConfig:
|
|
|
246
281
|
if key[5:] not in RUNTIME_KEYWORDS and key[5:] not in cls.submodules
|
|
247
282
|
}
|
|
248
283
|
|
|
249
|
-
if len(rbln_kwargs) > 0:
|
|
250
|
-
raise ValueError(f"Cannot set the following arguments: {list(rbln_kwargs.keys())}")
|
|
251
|
-
|
|
252
284
|
# Process submodule's rbln_config
|
|
253
285
|
for submodule in cls.submodules:
|
|
254
286
|
if submodule not in config_file:
|
|
@@ -263,6 +295,16 @@ class RBLNAutoConfig:
|
|
|
263
295
|
|
|
264
296
|
config_file.update(rbln_runtime_kwargs)
|
|
265
297
|
|
|
298
|
+
rbln_config = cls(**config_file)
|
|
299
|
+
|
|
300
|
+
if len(rbln_kwargs) > 0:
|
|
301
|
+
for key, value in rbln_kwargs.items():
|
|
302
|
+
if getattr(rbln_config, key) != value:
|
|
303
|
+
raise ValueError(
|
|
304
|
+
f"Cannot set the following arguments: {list(rbln_kwargs.keys())} "
|
|
305
|
+
f"Since the value is already set to {getattr(rbln_config, key)}"
|
|
306
|
+
)
|
|
307
|
+
|
|
266
308
|
if return_unused_kwargs:
|
|
267
309
|
return cls(**config_file), kwargs
|
|
268
310
|
else:
|
|
@@ -273,6 +315,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
273
315
|
"""Base configuration class for RBLN models that handles compilation settings, runtime options, and submodules.
|
|
274
316
|
|
|
275
317
|
This class provides functionality for:
|
|
318
|
+
|
|
276
319
|
1. Managing compilation configurations for RBLN devices
|
|
277
320
|
2. Configuring runtime behavior such as device placement
|
|
278
321
|
3. Handling nested configuration objects for complex model architectures
|
|
@@ -474,29 +517,31 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
474
517
|
non_save_attributes = [
|
|
475
518
|
"_frozen",
|
|
476
519
|
"_runtime_options",
|
|
520
|
+
"torch_dtype",
|
|
477
521
|
"npu",
|
|
478
522
|
"tensor_parallel_size",
|
|
479
523
|
"create_runtimes",
|
|
480
|
-
"optimize_host_memory",
|
|
481
524
|
"device",
|
|
482
525
|
"device_map",
|
|
483
526
|
"activate_profiler",
|
|
527
|
+
"timeout",
|
|
484
528
|
]
|
|
485
529
|
submodules: List[str] = []
|
|
486
530
|
subclass_non_save_attributes = []
|
|
531
|
+
_allow_no_compile_cfgs = False
|
|
487
532
|
|
|
488
|
-
def
|
|
533
|
+
def initialize_submodule_config(
|
|
489
534
|
self,
|
|
490
|
-
submodule_config_cls: Type["RBLNModelConfig"],
|
|
491
535
|
submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
492
|
-
|
|
536
|
+
force_kwargs: bool = False,
|
|
537
|
+
**kwargs: Any,
|
|
493
538
|
) -> "RBLNModelConfig":
|
|
494
|
-
# Initialize a submodule config from a dict or a RBLNModelConfig.
|
|
495
|
-
# kwargs is specified from the predecessor config.
|
|
496
|
-
|
|
497
539
|
if submodule_config is None:
|
|
498
540
|
submodule_config = {}
|
|
499
541
|
|
|
542
|
+
if isinstance(submodule_config, RBLNModelConfig):
|
|
543
|
+
return submodule_config
|
|
544
|
+
|
|
500
545
|
if isinstance(submodule_config, dict):
|
|
501
546
|
from_predecessor = self._runtime_options.copy()
|
|
502
547
|
from_predecessor.update(
|
|
@@ -510,13 +555,60 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
510
555
|
|
|
511
556
|
init_kwargs = from_predecessor
|
|
512
557
|
init_kwargs.update(submodule_config)
|
|
513
|
-
submodule_config = submodule_config_cls(**init_kwargs)
|
|
514
558
|
|
|
515
|
-
|
|
559
|
+
if force_kwargs:
|
|
560
|
+
for key, value in kwargs.items():
|
|
561
|
+
if key in init_kwargs:
|
|
562
|
+
if init_kwargs[key] != value:
|
|
563
|
+
raise ValueError(
|
|
564
|
+
f"Parameter conflict for '{key}': submodule_config has {init_kwargs[key]}, "
|
|
565
|
+
f"but kwargs has {value}. Using kwargs value: {value}"
|
|
566
|
+
)
|
|
567
|
+
init_kwargs[key] = value
|
|
568
|
+
|
|
569
|
+
if "cls_name" in init_kwargs:
|
|
570
|
+
config_cls = get_rbln_config_class(init_kwargs["cls_name"])
|
|
571
|
+
else:
|
|
572
|
+
return init_kwargs
|
|
573
|
+
|
|
574
|
+
submodule_config = config_cls(**init_kwargs)
|
|
575
|
+
|
|
576
|
+
if not isinstance(submodule_config, RBLNModelConfig):
|
|
516
577
|
raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")
|
|
517
578
|
|
|
518
579
|
return submodule_config
|
|
519
580
|
|
|
581
|
+
def filter_parameters(self, config_cls: Type["RBLNModelConfig"], parameters: Dict[str, Any]) -> Dict[str, Any]:
|
|
582
|
+
import importlib
|
|
583
|
+
|
|
584
|
+
model_cls_name = config_cls.__name__.replace("Config", "")
|
|
585
|
+
modeling_module_name = config_cls.__module__.replace("configuration_", "modeling_")
|
|
586
|
+
|
|
587
|
+
model_cls = None
|
|
588
|
+
try:
|
|
589
|
+
modeling_module = importlib.import_module(modeling_module_name)
|
|
590
|
+
if hasattr(modeling_module, model_cls_name):
|
|
591
|
+
model_cls = getattr(modeling_module, model_cls_name)
|
|
592
|
+
except ImportError:
|
|
593
|
+
logger.debug(f"Could not import modeling module: {modeling_module_name}")
|
|
594
|
+
|
|
595
|
+
filtered_out_params = set()
|
|
596
|
+
|
|
597
|
+
if model_cls is not None:
|
|
598
|
+
if not getattr(model_cls, "_tp_support", False):
|
|
599
|
+
filtered_out_params.add("tensor_parallel_size")
|
|
600
|
+
|
|
601
|
+
filtered_params = {}
|
|
602
|
+
for key, value in parameters.items():
|
|
603
|
+
if key in filtered_out_params:
|
|
604
|
+
logger.debug(
|
|
605
|
+
f"Parameter '{key}' filtered out for {config_cls.__name__} (not supported by model flags)."
|
|
606
|
+
)
|
|
607
|
+
else:
|
|
608
|
+
filtered_params[key] = value
|
|
609
|
+
|
|
610
|
+
return filtered_params
|
|
611
|
+
|
|
520
612
|
def __setattr__(self, key, value):
|
|
521
613
|
if (
|
|
522
614
|
key != "_attributes_map"
|
|
@@ -555,15 +647,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
555
647
|
self,
|
|
556
648
|
cls_name: Optional[str] = None,
|
|
557
649
|
create_runtimes: Optional[bool] = None,
|
|
558
|
-
optimize_host_memory: Optional[bool] = None,
|
|
559
650
|
device: Optional[Union[int, List[int]]] = None,
|
|
560
651
|
device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
|
|
561
652
|
activate_profiler: Optional[bool] = None,
|
|
562
653
|
npu: Optional[str] = None,
|
|
563
654
|
tensor_parallel_size: Optional[int] = None,
|
|
655
|
+
timeout: Optional[int] = None,
|
|
564
656
|
optimum_rbln_version: Optional[str] = None,
|
|
657
|
+
_torch_dtype: Optional[str] = None,
|
|
565
658
|
_compile_cfgs: List[RBLNCompileConfig] = [],
|
|
566
|
-
|
|
659
|
+
*,
|
|
660
|
+
optimize_host_memory: Optional[bool] = None,
|
|
661
|
+
**kwargs: Any,
|
|
567
662
|
):
|
|
568
663
|
"""
|
|
569
664
|
Initialize a RBLN model configuration with runtime options and compile configurations.
|
|
@@ -571,15 +666,16 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
571
666
|
Args:
|
|
572
667
|
cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
|
|
573
668
|
create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True.
|
|
574
|
-
optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
|
|
575
669
|
device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
|
|
576
670
|
device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
|
|
577
671
|
activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
|
|
578
672
|
npu (Optional[str]): The NPU device name to use for compilation.
|
|
579
673
|
tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
|
|
674
|
+
timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
|
|
580
675
|
optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
|
|
676
|
+
_torch_dtype (Optional[str]): The data type to use for the model.
|
|
581
677
|
_compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
|
|
582
|
-
|
|
678
|
+
kwargs: Additional keyword arguments.
|
|
583
679
|
|
|
584
680
|
Raises:
|
|
585
681
|
ValueError: If unexpected keyword arguments are provided.
|
|
@@ -595,15 +691,19 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
595
691
|
|
|
596
692
|
self._runtime_options = {}
|
|
597
693
|
self._runtime_options["create_runtimes"] = create_runtimes
|
|
598
|
-
self._runtime_options["optimize_host_memory"] = optimize_host_memory
|
|
599
694
|
self._runtime_options["device"] = device
|
|
600
695
|
self._runtime_options["device_map"] = device_map
|
|
601
696
|
self._runtime_options["activate_profiler"] = activate_profiler
|
|
697
|
+
self._runtime_options["timeout"] = timeout
|
|
698
|
+
|
|
699
|
+
if optimize_host_memory is not None:
|
|
700
|
+
logger.warning("`optimize_host_memory` is deprecated and will be removed in future versions.")
|
|
602
701
|
|
|
603
702
|
# Automatically pass npu, tensor_parallel_size to compile_cfgs
|
|
604
703
|
self.npu = npu
|
|
605
704
|
self.tensor_parallel_size = tensor_parallel_size
|
|
606
705
|
|
|
706
|
+
self._torch_dtype = _torch_dtype or "float32"
|
|
607
707
|
self.optimum_rbln_version = optimum_rbln_version
|
|
608
708
|
if self.optimum_rbln_version is None:
|
|
609
709
|
self.optimum_rbln_version = __version__
|
|
@@ -616,8 +716,34 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
616
716
|
self.set_compile_cfgs([RBLNCompileConfig(**cfg) for cfg in self._compile_cfgs])
|
|
617
717
|
|
|
618
718
|
if len(kwargs) > 0:
|
|
719
|
+
if optimum_rbln_version is not None: # loaded from file
|
|
720
|
+
if Version(__version__) < Version(optimum_rbln_version):
|
|
721
|
+
diff = "newer"
|
|
722
|
+
elif Version(__version__) > Version(optimum_rbln_version):
|
|
723
|
+
diff = "older"
|
|
724
|
+
else:
|
|
725
|
+
diff = None
|
|
726
|
+
if diff is not None:
|
|
727
|
+
raise ValueError(
|
|
728
|
+
f"Unexpected arguments: {kwargs.keys()}\n"
|
|
729
|
+
f"Maybe you are trying to load a model compiled with {diff} version of optimum-rbln. "
|
|
730
|
+
"It is recommended to use the same version to compile and load the model.\n"
|
|
731
|
+
f"Current version: {__version__}, Loaded version: {optimum_rbln_version}"
|
|
732
|
+
)
|
|
733
|
+
|
|
619
734
|
raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
|
|
620
735
|
|
|
736
|
+
@property
|
|
737
|
+
def torch_dtype(self):
|
|
738
|
+
return getattr(torch, self._torch_dtype)
|
|
739
|
+
|
|
740
|
+
@torch_dtype.setter
|
|
741
|
+
def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
|
|
742
|
+
if isinstance(torch_dtype, torch.dtype):
|
|
743
|
+
torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
|
|
744
|
+
|
|
745
|
+
self._torch_dtype = torch_dtype
|
|
746
|
+
|
|
621
747
|
@property
|
|
622
748
|
def rbln_model_cls_name(self) -> str:
|
|
623
749
|
return self.__class__.__name__[:-6]
|
|
@@ -671,6 +797,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
671
797
|
compile_cfg.npu = self.npu
|
|
672
798
|
compile_cfg.tensor_parallel_size = self.tensor_parallel_size
|
|
673
799
|
|
|
800
|
+
target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
|
|
801
|
+
warn_deprecated_npu(target_npu)
|
|
802
|
+
|
|
674
803
|
def freeze(self):
|
|
675
804
|
if self._frozen:
|
|
676
805
|
raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
|
|
@@ -680,7 +809,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
680
809
|
or len(self._compile_cfgs) == 0
|
|
681
810
|
or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
|
|
682
811
|
):
|
|
683
|
-
|
|
812
|
+
if not self._allow_no_compile_cfgs:
|
|
813
|
+
raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")
|
|
684
814
|
|
|
685
815
|
for submodule_name in self.submodules:
|
|
686
816
|
submodule_config = getattr(self, submodule_name, None)
|
|
@@ -709,13 +839,13 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
709
839
|
json.dump(serializable_data, jsonf, indent=2)
|
|
710
840
|
|
|
711
841
|
@classmethod
|
|
712
|
-
def load(cls, path: str, **kwargs:
|
|
842
|
+
def load(cls, path: str, **kwargs: Any) -> "RBLNModelConfig":
|
|
713
843
|
"""
|
|
714
844
|
Load a RBLNModelConfig from a path.
|
|
715
845
|
|
|
716
846
|
Args:
|
|
717
847
|
path (str): Path to the RBLNModelConfig file or directory containing the config file.
|
|
718
|
-
|
|
848
|
+
kwargs: Additional keyword arguments to override configuration values.
|
|
719
849
|
Keys starting with 'rbln_' will have the prefix removed and be used
|
|
720
850
|
to update the configuration.
|
|
721
851
|
|
|
@@ -742,7 +872,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
742
872
|
def initialize_from_kwargs(
|
|
743
873
|
cls: Type["RBLNModelConfig"],
|
|
744
874
|
rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
|
|
745
|
-
**kwargs:
|
|
875
|
+
**kwargs: Any,
|
|
746
876
|
) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
|
|
747
877
|
# Initialize RBLNModelConfig from kwargs.
|
|
748
878
|
kwargs_keys = list(kwargs.keys())
|
|
@@ -787,19 +917,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
787
917
|
def create_runtimes(self, create_runtimes: bool):
|
|
788
918
|
self._runtime_options["create_runtimes"] = create_runtimes
|
|
789
919
|
|
|
790
|
-
@property
|
|
791
|
-
def optimize_host_memory(self):
|
|
792
|
-
context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
|
|
793
|
-
if context is not None:
|
|
794
|
-
return context
|
|
795
|
-
elif self._runtime_options["optimize_host_memory"] is None:
|
|
796
|
-
return True
|
|
797
|
-
return self._runtime_options["optimize_host_memory"]
|
|
798
|
-
|
|
799
|
-
@optimize_host_memory.setter
|
|
800
|
-
def optimize_host_memory(self, optimize_host_memory: bool):
|
|
801
|
-
self._runtime_options["optimize_host_memory"] = optimize_host_memory
|
|
802
|
-
|
|
803
920
|
@property
|
|
804
921
|
def device(self):
|
|
805
922
|
context = ContextRblnConfig.get_current_context()["device"]
|
|
@@ -838,3 +955,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
838
955
|
@activate_profiler.setter
|
|
839
956
|
def activate_profiler(self, activate_profiler: bool):
|
|
840
957
|
self._runtime_options["activate_profiler"] = activate_profiler
|
|
958
|
+
|
|
959
|
+
@property
|
|
960
|
+
def timeout(self):
|
|
961
|
+
context = ContextRblnConfig.get_current_context()["timeout"]
|
|
962
|
+
if context is not None:
|
|
963
|
+
return context
|
|
964
|
+
return self._runtime_options["timeout"]
|
|
965
|
+
|
|
966
|
+
@timeout.setter
|
|
967
|
+
def timeout(self, timeout: int):
|
|
968
|
+
self._runtime_options["timeout"] = timeout
|
|
@@ -57,8 +57,14 @@ _import_structure = {
|
|
|
57
57
|
"RBLNSD3Transformer2DModelConfig",
|
|
58
58
|
"RBLNUNet2DConditionModelConfig",
|
|
59
59
|
"RBLNVQModelConfig",
|
|
60
|
+
"RBLNUNetSpatioTemporalConditionModelConfig",
|
|
61
|
+
"RBLNStableVideoDiffusionPipelineConfig",
|
|
62
|
+
"RBLNAutoencoderKLTemporalDecoderConfig",
|
|
60
63
|
],
|
|
61
64
|
"pipelines": [
|
|
65
|
+
"RBLNAutoPipelineForImage2Image",
|
|
66
|
+
"RBLNAutoPipelineForInpainting",
|
|
67
|
+
"RBLNAutoPipelineForText2Image",
|
|
62
68
|
"RBLNCosmosTextToWorldPipeline",
|
|
63
69
|
"RBLNCosmosVideoToWorldPipeline",
|
|
64
70
|
"RBLNCosmosSafetyChecker",
|
|
@@ -83,14 +89,17 @@ _import_structure = {
|
|
|
83
89
|
"RBLNStableDiffusion3Pipeline",
|
|
84
90
|
"RBLNStableDiffusion3Img2ImgPipeline",
|
|
85
91
|
"RBLNStableDiffusion3InpaintPipeline",
|
|
92
|
+
"RBLNStableVideoDiffusionPipeline",
|
|
86
93
|
],
|
|
87
94
|
"models": [
|
|
88
95
|
"RBLNAutoencoderKL",
|
|
89
96
|
"RBLNAutoencoderKLCosmos",
|
|
90
97
|
"RBLNUNet2DConditionModel",
|
|
98
|
+
"RBLNUNetSpatioTemporalConditionModel",
|
|
91
99
|
"RBLNControlNetModel",
|
|
92
100
|
"RBLNCosmosTransformer3DModel",
|
|
93
101
|
"RBLNSD3Transformer2DModel",
|
|
102
|
+
"RBLNAutoencoderKLTemporalDecoder",
|
|
94
103
|
"RBLNPriorTransformer",
|
|
95
104
|
"RBLNVQModel",
|
|
96
105
|
],
|
|
@@ -103,6 +112,7 @@ if TYPE_CHECKING:
|
|
|
103
112
|
from .configurations import (
|
|
104
113
|
RBLNAutoencoderKLConfig,
|
|
105
114
|
RBLNAutoencoderKLCosmosConfig,
|
|
115
|
+
RBLNAutoencoderKLTemporalDecoderConfig,
|
|
106
116
|
RBLNControlNetModelConfig,
|
|
107
117
|
RBLNCosmosTextToWorldPipelineConfig,
|
|
108
118
|
RBLNCosmosTransformer3DModelConfig,
|
|
@@ -129,20 +139,28 @@ if TYPE_CHECKING:
|
|
|
129
139
|
RBLNStableDiffusionXLImg2ImgPipelineConfig,
|
|
130
140
|
RBLNStableDiffusionXLInpaintPipelineConfig,
|
|
131
141
|
RBLNStableDiffusionXLPipelineConfig,
|
|
142
|
+
RBLNStableVideoDiffusionPipelineConfig,
|
|
132
143
|
RBLNUNet2DConditionModelConfig,
|
|
144
|
+
RBLNUNetSpatioTemporalConditionModelConfig,
|
|
133
145
|
RBLNVQModelConfig,
|
|
134
146
|
)
|
|
135
147
|
from .modeling_diffusers import RBLNDiffusionMixin
|
|
136
148
|
from .models import (
|
|
137
149
|
RBLNAutoencoderKL,
|
|
150
|
+
RBLNAutoencoderKLCosmos,
|
|
151
|
+
RBLNAutoencoderKLTemporalDecoder,
|
|
138
152
|
RBLNControlNetModel,
|
|
139
153
|
RBLNCosmosTransformer3DModel,
|
|
140
154
|
RBLNPriorTransformer,
|
|
141
155
|
RBLNSD3Transformer2DModel,
|
|
142
156
|
RBLNUNet2DConditionModel,
|
|
157
|
+
RBLNUNetSpatioTemporalConditionModel,
|
|
143
158
|
RBLNVQModel,
|
|
144
159
|
)
|
|
145
160
|
from .pipelines import (
|
|
161
|
+
RBLNAutoPipelineForImage2Image,
|
|
162
|
+
RBLNAutoPipelineForInpainting,
|
|
163
|
+
RBLNAutoPipelineForText2Image,
|
|
146
164
|
RBLNCosmosSafetyChecker,
|
|
147
165
|
RBLNCosmosTextToWorldPipeline,
|
|
148
166
|
RBLNCosmosVideoToWorldPipeline,
|
|
@@ -167,6 +185,7 @@ if TYPE_CHECKING:
|
|
|
167
185
|
RBLNStableDiffusionXLImg2ImgPipeline,
|
|
168
186
|
RBLNStableDiffusionXLInpaintPipeline,
|
|
169
187
|
RBLNStableDiffusionXLPipeline,
|
|
188
|
+
RBLNStableVideoDiffusionPipeline,
|
|
170
189
|
)
|
|
171
190
|
else:
|
|
172
191
|
import sys
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from .models import (
|
|
2
2
|
RBLNAutoencoderKLConfig,
|
|
3
3
|
RBLNAutoencoderKLCosmosConfig,
|
|
4
|
+
RBLNAutoencoderKLTemporalDecoderConfig,
|
|
4
5
|
RBLNControlNetModelConfig,
|
|
5
6
|
RBLNCosmosTransformer3DModelConfig,
|
|
6
7
|
RBLNPriorTransformerConfig,
|
|
7
8
|
RBLNSD3Transformer2DModelConfig,
|
|
8
9
|
RBLNUNet2DConditionModelConfig,
|
|
10
|
+
RBLNUNetSpatioTemporalConditionModelConfig,
|
|
9
11
|
RBLNVQModelConfig,
|
|
10
12
|
)
|
|
11
13
|
from .pipelines import (
|
|
@@ -31,4 +33,5 @@ from .pipelines import (
|
|
|
31
33
|
RBLNStableDiffusionXLImg2ImgPipelineConfig,
|
|
32
34
|
RBLNStableDiffusionXLInpaintPipelineConfig,
|
|
33
35
|
RBLNStableDiffusionXLPipelineConfig,
|
|
36
|
+
RBLNStableVideoDiffusionPipelineConfig,
|
|
34
37
|
)
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
|
|
2
2
|
from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
|
|
3
|
+
from .configuration_autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoderConfig
|
|
3
4
|
from .configuration_controlnet import RBLNControlNetModelConfig
|
|
4
5
|
from .configuration_prior_transformer import RBLNPriorTransformerConfig
|
|
5
6
|
from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
|
|
6
7
|
from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
|
|
7
8
|
from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
|
|
9
|
+
from .configuration_unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModelConfig
|
|
8
10
|
from .configuration_vq_model import RBLNVQModelConfig
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -33,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
|
|
|
33
33
|
vae_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
|
|
34
34
|
in_channels: Optional[int] = None,
|
|
35
35
|
latent_channels: Optional[int] = None,
|
|
36
|
-
**kwargs:
|
|
36
|
+
**kwargs: Any,
|
|
37
37
|
):
|
|
38
38
|
"""
|
|
39
39
|
Args:
|
|
@@ -46,7 +46,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
|
|
|
46
46
|
Determines how much smaller the latent representations are compared to the original images.
|
|
47
47
|
in_channels (Optional[int]): Number of input channels for the model.
|
|
48
48
|
latent_channels (Optional[int]): Number of channels in the latent space.
|
|
49
|
-
|
|
49
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
50
50
|
|
|
51
51
|
Raises:
|
|
52
52
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -52,7 +52,7 @@ class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
|
|
|
52
52
|
Determines how much smaller the latent representations are compared to the original videos.
|
|
53
53
|
use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
|
|
54
54
|
If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
|
|
55
|
-
|
|
55
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
56
56
|
|
|
57
57
|
Raises:
|
|
58
58
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNAutoencoderKLTemporalDecoderConfig(RBLNModelConfig):
|
|
21
|
+
def __init__(
|
|
22
|
+
self,
|
|
23
|
+
batch_size: Optional[int] = None,
|
|
24
|
+
sample_size: Optional[Tuple[int, int]] = None,
|
|
25
|
+
uses_encoder: Optional[bool] = None,
|
|
26
|
+
num_frames: Optional[int] = None,
|
|
27
|
+
decode_chunk_size: Optional[int] = None,
|
|
28
|
+
vae_scale_factor: Optional[float] = None,
|
|
29
|
+
**kwargs: Any,
|
|
30
|
+
):
|
|
31
|
+
"""
|
|
32
|
+
Args:
|
|
33
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
34
|
+
sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the input/output images.
|
|
35
|
+
If an integer is provided, it's used for both height and width.
|
|
36
|
+
uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
|
|
37
|
+
When False, only the decoder is used (for latent-to-image conversion).
|
|
38
|
+
num_frames (Optional[int]): The number of frames in the generated video.
|
|
39
|
+
decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
|
|
40
|
+
Useful for managing memory usage during video generation.
|
|
41
|
+
vae_scale_factor (Optional[float]): The scaling factor between pixel space and latent space.
|
|
42
|
+
Determines how much smaller the latent representations are compared to the original images.
|
|
43
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
44
|
+
|
|
45
|
+
Raises:
|
|
46
|
+
ValueError: If batch_size is not a positive integer.
|
|
47
|
+
"""
|
|
48
|
+
super().__init__(**kwargs)
|
|
49
|
+
self.batch_size = batch_size or 1
|
|
50
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
|
51
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
52
|
+
|
|
53
|
+
self.uses_encoder = uses_encoder
|
|
54
|
+
self.num_frames = num_frames
|
|
55
|
+
self.decode_chunk_size = decode_chunk_size
|
|
56
|
+
self.vae_scale_factor = vae_scale_factor
|
|
57
|
+
self.sample_size = sample_size
|
|
58
|
+
if isinstance(sample_size, int):
|
|
59
|
+
self.sample_size = (sample_size, sample_size)
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def image_size(self):
|
|
63
|
+
return self.sample_size
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def latent_sample_size(self):
|
|
67
|
+
return (self.image_size[0] // self.vae_scale_factor, self.image_size[1] // self.vae_scale_factor)
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -29,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
|
|
|
29
29
|
unet_sample_size: Optional[Tuple[int, int]] = None,
|
|
30
30
|
vae_sample_size: Optional[Tuple[int, int]] = None,
|
|
31
31
|
text_model_hidden_size: Optional[int] = None,
|
|
32
|
-
**kwargs:
|
|
32
|
+
**kwargs: Any,
|
|
33
33
|
):
|
|
34
34
|
"""
|
|
35
35
|
Args:
|
|
@@ -42,7 +42,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
|
|
|
42
42
|
of the VAE input/output images.
|
|
43
43
|
text_model_hidden_size (Optional[int]): Hidden size of the text encoder model used
|
|
44
44
|
for conditioning.
|
|
45
|
-
|
|
45
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
46
46
|
|
|
47
47
|
Raises:
|
|
48
48
|
ValueError: If batch_size is not a positive integer.
|