optimum-rbln 0.8.0.post1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +53 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post1.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post1.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post1.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post1.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
```diff
@@ -26,6 +26,7 @@ _import_structure = {
         "RBLNModel",
     ],
     "configuration_utils": [
+        "RBLNAutoConfig",
         "RBLNCompileConfig",
         "RBLNModelConfig",
     ],
@@ -69,6 +70,8 @@ _import_structure = {
         "RBLNCLIPVisionModelConfig",
         "RBLNCLIPVisionModelWithProjection",
         "RBLNCLIPVisionModelWithProjectionConfig",
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
         "RBLNDistilBertForQuestionAnswering",
@@ -135,8 +138,17 @@ _import_structure = {
     "diffusers": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLConfig",
+        "RBLNAutoencoderKLCosmos",
+        "RBLNAutoencoderKLCosmosConfig",
         "RBLNControlNetModel",
         "RBLNControlNetModelConfig",
+        "RBLNCosmosTextToWorldPipeline",
+        "RBLNCosmosVideoToWorldPipeline",
+        "RBLNCosmosTextToWorldPipelineConfig",
+        "RBLNCosmosVideoToWorldPipelineConfig",
+        "RBLNCosmosSafetyChecker",
+        "RBLNCosmosTransformer3DModel",
+        "RBLNCosmosTransformer3DModelConfig",
         "RBLNDiffusionMixin",
         "RBLNKandinskyV22CombinedPipeline",
         "RBLNKandinskyV22CombinedPipelineConfig",
@@ -192,14 +204,24 @@ _import_structure = {
 
 if TYPE_CHECKING:
     from .configuration_utils import (
+        RBLNAutoConfig,
         RBLNCompileConfig,
         RBLNModelConfig,
    )
    from .diffusers import (
        RBLNAutoencoderKL,
        RBLNAutoencoderKLConfig,
+        RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLCosmosConfig,
        RBLNControlNetModel,
        RBLNControlNetModelConfig,
+        RBLNCosmosSafetyChecker,
+        RBLNCosmosTextToWorldPipeline,
+        RBLNCosmosTextToWorldPipelineConfig,
+        RBLNCosmosTransformer3DModel,
+        RBLNCosmosTransformer3DModelConfig,
+        RBLNCosmosVideoToWorldPipeline,
+        RBLNCosmosVideoToWorldPipelineConfig,
        RBLNDiffusionMixin,
        RBLNKandinskyV22CombinedPipeline,
        RBLNKandinskyV22CombinedPipelineConfig,
@@ -295,6 +317,8 @@ if TYPE_CHECKING:
        RBLNCLIPVisionModelConfig,
        RBLNCLIPVisionModelWithProjection,
        RBLNCLIPVisionModelWithProjectionConfig,
+        RBLNColPaliForRetrieval,
+        RBLNColPaliForRetrievalConfig,
        RBLNDecoderOnlyModelForCausalLM,
        RBLNDecoderOnlyModelForCausalLMConfig,
        RBLNDistilBertForQuestionAnswering,
```
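For reference, a minimal sketch of what the expanded top-level namespace looks like in 0.8.1. The imported names (Cosmos pipelines and configs, plus the new `RBLNAutoConfig` and ColPali exports) are exactly the ones added to `optimum/rbln/__init__.py` above; the `from_pretrained(..., export=True, rbln_config=...)` call follows the usual optimum-rbln compile-and-load convention, and the checkpoint id is a hypothetical placeholder, not something taken from this diff.

```python
# Sketch only: imports match names newly exported by optimum/rbln/__init__.py in 0.8.1;
# the compile call and checkpoint id below are assumptions, not part of this diff.
from optimum.rbln import (
    RBLNCosmosTextToWorldPipeline,
    RBLNCosmosTextToWorldPipelineConfig,
)

pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",   # placeholder checkpoint id
    export=True,                                   # compile for RBLN NPUs (assumed convention)
    rbln_config=RBLNCosmosTextToWorldPipelineConfig(),
)
```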
optimum/rbln/__version__.py
CHANGED
```diff
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.8.0.post1'
-__version_tuple__ = version_tuple = (0, 8, 0, 'post1')
+__version__ = version = '0.8.1'
+__version_tuple__ = version_tuple = (0, 8, 1)
```
optimum/rbln/configuration_utils.py
CHANGED
````diff
@@ -19,6 +19,7 @@ from dataclasses import asdict, dataclass
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, Union, runtime_checkable
 
+import numpy as np
 import torch
 
 from .__version__ import __version__
@@ -61,7 +62,7 @@ class RBLNCompileConfig:
     tensor_parallel_size: Optional[int] = None
 
     @staticmethod
-    def normalize_dtype(dtype):
+    def normalize_dtype(dtype: Union[str, torch.dtype, np.dtype]) -> str:
        """
        Convert framework-specific dtype to string representation.
        i.e. torch.float32 -> "float32"
@@ -70,7 +71,7 @@ class RBLNCompileConfig:
            dtype: The input dtype (can be string, torch dtype, or numpy dtype).
 
        Returns:
-
+            The normalized string representation of the dtype.
        """
        if isinstance(dtype, str):
            return dtype
@@ -147,6 +148,17 @@
 
 
 RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler"]
+CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
+
+
+def get_rbln_config_class(rbln_config_class_name: str) -> Type["RBLNModelConfig"]:
+    cls = getattr(importlib.import_module("optimum.rbln"), rbln_config_class_name, None)
+    if cls is None:
+        if rbln_config_class_name in CONFIG_MAPPING:
+            cls = CONFIG_MAPPING[rbln_config_class_name]
+        else:
+            raise ValueError(f"Configuration for {rbln_config_class_name} not found.")
+    return cls
 
 
 def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
@@ -166,7 +178,7 @@ def load_config(path: str) -> Tuple[Type["RBLNModelConfig"], Dict[str, Any]]:
    )
 
    cls_name = config_file["cls_name"]
-    cls =
+    cls = get_rbln_config_class(cls_name)
    return cls, config_file
 
 
@@ -175,7 +187,7 @@ class RBLNAutoConfig:
        cls_name = kwargs.get("cls_name")
        if cls_name is None:
            raise ValueError("`cls_name` is required.")
-        cls =
+        cls = get_rbln_config_class(cls_name)
        return cls(**kwargs)
 
    @staticmethod
@@ -183,9 +195,27 @@
        cls_name = config_dict.get("cls_name")
        if cls_name is None:
            raise ValueError("`cls_name` is required.")
-        cls =
+        cls = get_rbln_config_class(cls_name)
        return cls(**config_dict)
 
+    @staticmethod
+    def register(config: Type["RBLNModelConfig"], exist_ok=False):
+        """
+        Register a new configuration for this class.
+
+        Args:
+            config ([`RBLNModelConfig`]): The config to register.
+        """
+        if not issubclass(config, RBLNModelConfig):
+            raise ValueError("`config` must be a subclass of RBLNModelConfig.")
+
+        native_cls = getattr(importlib.import_module("optimum.rbln"), config.__name__, None)
+        if config.__name__ in CONFIG_MAPPING or native_cls is not None:
+            if not exist_ok:
+                raise ValueError(f"Configuration for {config.__name__} already registered.")
+
+        CONFIG_MAPPING[config.__name__] = config
+
    @staticmethod
    def load(
        path: str,
@@ -307,9 +337,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
        # Save to disk
        config.save("/path/to/model")
 
-        # Load configuration from disk
-        loaded_config = RBLNModelConfig.load("/path/to/model")
-
        # Using AutoConfig
        loaded_config = RBLNAutoConfig.load("/path/to/model")
        ```
@@ -462,19 +489,25 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
        self,
        submodule_config_cls: Type["RBLNModelConfig"],
        submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
    ) -> "RBLNModelConfig":
-
-
+        # Initialize a submodule config from a dict or a RBLNModelConfig.
+        # kwargs is specified from the predecessor config.
 
-        kwargs is specified from the predecessor config.
-        """
        if submodule_config is None:
            submodule_config = {}
 
        if isinstance(submodule_config, dict):
            from_predecessor = self._runtime_options.copy()
+            from_predecessor.update(
+                {
+                    "npu": self.npu,
+                    "tensor_parallel_size": self.tensor_parallel_size,
+                    "optimum_rbln_version": self.optimum_rbln_version,
+                }
+            )
            from_predecessor.update(kwargs)
+
            init_kwargs = from_predecessor
            init_kwargs.update(submodule_config)
            submodule_config = submodule_config_cls(**init_kwargs)
@@ -530,7 +563,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
        tensor_parallel_size: Optional[int] = None,
        optimum_rbln_version: Optional[str] = None,
        _compile_cfgs: List[RBLNCompileConfig] = [],
-        **kwargs,
+        **kwargs: Dict[str, Any],
    ):
        """
        Initialize a RBLN model configuration with runtime options and compile configurations.
@@ -600,10 +633,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
        return rbln_model_cls
 
    def _prepare_for_serialization(self) -> Dict[str, Any]:
-
-
-        objects to their serializable form.
-        """
+        # Prepare the attributes map for serialization by converting nested RBLNModelConfig
+        # objects to their serializable form.
        serializable_map = {}
        for key, value in self._attributes_map.items():
            if isinstance(value, RBLNSerializableConfigProtocol):
@@ -678,7 +709,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
            json.dump(serializable_data, jsonf, indent=2)
 
    @classmethod
-    def load(cls, path: str, **kwargs) -> "RBLNModelConfig":
+    def load(cls, path: str, **kwargs: Dict[str, Any]) -> "RBLNModelConfig":
        """
        Load a RBLNModelConfig from a path.
 
@@ -711,11 +742,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
    def initialize_from_kwargs(
        cls: Type["RBLNModelConfig"],
        rbln_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
    ) -> Tuple["RBLNModelConfig", Dict[str, Any]]:
-
-        Initialize RBLNModelConfig from kwargs.
-        """
+        # Initialize RBLNModelConfig from kwargs.
        kwargs_keys = list(kwargs.keys())
        rbln_kwargs = {key[5:]: kwargs.pop(key) for key in kwargs_keys if key.startswith("rbln_")}
 
@@ -733,16 +762,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
        return rbln_config, kwargs
 
    def get_default_values_for_original_cls(self, func_name: str, keys: List[str]) -> Dict[str, Any]:
-
-        Get default values for original class attributes from RBLNModelConfig.
-
-        Args:
-            func_name (str): The name of the function to get the default values for.
-            keys (List[str]): The keys of the attributes to get.
-
-        Returns:
-            Dict[str, Any]: The default values for the attributes.
-        """
+        # Get default values for original class attributes from RBLNModelConfig.
        model_cls = self.rbln_model_cls.get_hf_class()
        func = getattr(model_cls, func_name)
        func_signature = inspect.signature(func)
````
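The `CONFIG_MAPPING` / `get_rbln_config_class` / `RBLNAutoConfig.register` additions give `load_config` and `RBLNAutoConfig` a fallback lookup for config classes that do not live in the `optimum.rbln` namespace. A minimal sketch of the new hook; `RBLNMyCustomModelConfig` is a hypothetical user-defined config, while `RBLNAutoConfig.register()` and the `CONFIG_MAPPING` fallback are exactly the pieces introduced in the hunks above.

```python
# Sketch of the registration hook added in configuration_utils.py (0.8.1).
from optimum.rbln import RBLNAutoConfig, RBLNModelConfig


class RBLNMyCustomModelConfig(RBLNModelConfig):
    # Hypothetical custom config, only for illustration.
    def __init__(self, batch_size=None, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1


# Without this call, get_rbln_config_class("RBLNMyCustomModelConfig") would only
# search the optimum.rbln namespace and raise "Configuration for ... not found."
RBLNAutoConfig.register(RBLNMyCustomModelConfig)

# Registering the same name again raises ValueError unless exist_ok=True.
RBLNAutoConfig.register(RBLNMyCustomModelConfig, exist_ok=True)

# A saved rbln_config.json whose "cls_name" is "RBLNMyCustomModelConfig" can now be
# resolved through CONFIG_MAPPING, e.g. RBLNAutoConfig.load("/path/to/compiled/model").
```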
optimum/rbln/diffusers/__init__.py
CHANGED
```diff
@@ -18,14 +18,21 @@ from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_CLASSES
 from transformers.utils import _LazyModule
 
 
-LOADABLE_CLASSES["optimum.rbln"] = {
+LOADABLE_CLASSES["optimum.rbln"] = {
+    "RBLNBaseModel": ["save_pretrained", "from_pretrained"],
+    "RBLNCosmosSafetyChecker": ["save_pretrained", "from_pretrained"],
+}
 ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
 
 
 _import_structure = {
     "configurations": [
         "RBLNAutoencoderKLConfig",
+        "RBLNAutoencoderKLCosmosConfig",
         "RBLNControlNetModelConfig",
+        "RBLNCosmosTextToWorldPipelineConfig",
+        "RBLNCosmosVideoToWorldPipelineConfig",
+        "RBLNCosmosTransformer3DModelConfig",
         "RBLNKandinskyV22CombinedPipelineConfig",
         "RBLNKandinskyV22Img2ImgCombinedPipelineConfig",
         "RBLNKandinskyV22Img2ImgPipelineConfig",
@@ -52,6 +59,9 @@ _import_structure = {
         "RBLNVQModelConfig",
     ],
     "pipelines": [
+        "RBLNCosmosTextToWorldPipeline",
+        "RBLNCosmosVideoToWorldPipeline",
+        "RBLNCosmosSafetyChecker",
         "RBLNKandinskyV22CombinedPipeline",
         "RBLNKandinskyV22Img2ImgCombinedPipeline",
         "RBLNKandinskyV22InpaintCombinedPipeline",
@@ -76,8 +86,10 @@ _import_structure = {
     ],
     "models": [
         "RBLNAutoencoderKL",
+        "RBLNAutoencoderKLCosmos",
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
+        "RBLNCosmosTransformer3DModel",
         "RBLNSD3Transformer2DModel",
         "RBLNPriorTransformer",
         "RBLNVQModel",
@@ -90,7 +102,11 @@ _import_structure = {
 if TYPE_CHECKING:
     from .configurations import (
         RBLNAutoencoderKLConfig,
+        RBLNAutoencoderKLCosmosConfig,
         RBLNControlNetModelConfig,
+        RBLNCosmosTextToWorldPipelineConfig,
+        RBLNCosmosTransformer3DModelConfig,
+        RBLNCosmosVideoToWorldPipelineConfig,
         RBLNKandinskyV22CombinedPipelineConfig,
         RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
         RBLNKandinskyV22Img2ImgPipelineConfig,
@@ -120,12 +136,16 @@ if TYPE_CHECKING:
     from .models import (
         RBLNAutoencoderKL,
         RBLNControlNetModel,
+        RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNUNet2DConditionModel,
         RBLNVQModel,
     )
     from .pipelines import (
+        RBLNCosmosSafetyChecker,
+        RBLNCosmosTextToWorldPipeline,
+        RBLNCosmosVideoToWorldPipeline,
         RBLNKandinskyV22CombinedPipeline,
         RBLNKandinskyV22Img2ImgCombinedPipeline,
         RBLNKandinskyV22Img2ImgPipeline,
```
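The `LOADABLE_CLASSES` edit is what lets `diffusers` pipelines treat RBLN-wrapped components, now including the Cosmos safety checker, as saveable/loadable submodules. A small sketch of how the registration can be observed; both dicts are real `diffusers` globals (they appear in the hunk header above), and the printed mapping mirrors the added lines.

```python
# Importing optimum.rbln.diffusers runs the module-level registration shown above.
from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_CLASSES

import optimum.rbln.diffusers  # noqa: F401  (side effect: registers the RBLN classes)

print(LOADABLE_CLASSES["optimum.rbln"])
# {'RBLNBaseModel': ['save_pretrained', 'from_pretrained'],
#  'RBLNCosmosSafetyChecker': ['save_pretrained', 'from_pretrained']}

print("RBLNCosmosSafetyChecker" in ALL_IMPORTABLE_CLASSES)  # True
```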
optimum/rbln/diffusers/configurations/__init__.py
CHANGED
```diff
@@ -1,12 +1,16 @@
 from .models import (
     RBLNAutoencoderKLConfig,
+    RBLNAutoencoderKLCosmosConfig,
     RBLNControlNetModelConfig,
+    RBLNCosmosTransformer3DModelConfig,
     RBLNPriorTransformerConfig,
     RBLNSD3Transformer2DModelConfig,
     RBLNUNet2DConditionModelConfig,
     RBLNVQModelConfig,
 )
 from .pipelines import (
+    RBLNCosmosTextToWorldPipelineConfig,
+    RBLNCosmosVideoToWorldPipelineConfig,
     RBLNKandinskyV22CombinedPipelineConfig,
     RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
     RBLNKandinskyV22Img2ImgPipelineConfig,
```
optimum/rbln/diffusers/configurations/models/__init__.py
CHANGED
```diff
@@ -1,6 +1,8 @@
 from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
+from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
 from .configuration_controlnet import RBLNControlNetModelConfig
 from .configuration_prior_transformer import RBLNPriorTransformerConfig
+from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
 from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
 from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
 from .configuration_vq_model import RBLNVQModelConfig
```
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py
CHANGED
```diff
@@ -12,12 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNAutoencoderKLConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLN Variational Autoencoder (VAE) models.
+
+    This class inherits from RBLNModelConfig and provides specific configuration options
+    for VAE models used in diffusion-based image generation.
+    """
+
     def __init__(
         self,
         batch_size: Optional[int] = None,
@@ -26,7 +33,7 @@ class RBLNAutoencoderKLConfig(RBLNModelConfig):
         vae_scale_factor: Optional[float] = None,  # TODO: rename to scaling_factor
         in_channels: Optional[int] = None,
         latent_channels: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py
ADDED
```diff
@@ -0,0 +1,84 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
+    """Configuration class for RBLN Cosmos Variational Autoencoder (VAE) models."""
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        uses_encoder: Optional[bool] = None,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_channels_latents: Optional[int] = None,
+        vae_scale_factor_temporal: Optional[int] = None,
+        vae_scale_factor_spatial: Optional[int] = None,
+        use_slicing: Optional[bool] = None,
+        **kwargs: Dict[str, Any],
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
+                When False, only the decoder is used (for latent-to-video conversion).
+            num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
+            height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
+            width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
+            num_channels_latents (Optional[int]): The number of channels in latent space.
+            vae_scale_factor_temporal (Optional[int]): The scaling factor between time space and latent space.
+                Determines how much shorter the latent representations are compared to the original videos.
+            vae_scale_factor_spatial (Optional[int]): The scaling factor between pixel space and latent space.
+                Determines how much smaller the latent representations are compared to the original videos.
+            use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
+                If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        # Since the Cosmos VAE Decoder already requires approximately 7.9 GiB of memory,
+        # Optimum-rbln cannot execute this model on RBLN-CA12 when the batch size > 1.
+        # However, the Cosmos VAE Decoder propose batch slicing when the batch size is greater than 1,
+        # Optimum-rbln utilize this method by compiling with batch_size=1 to enable batch slicing.
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+        elif self.batch_size > 1:
+            logger.warning("The batch size of Cosmos VAE Decoder will be explicitly 1 for memory efficiency.")
+            self.batch_size = 1
+
+        self.uses_encoder = uses_encoder
+        self.num_frames = num_frames or 121
+        self.height = height or 704
+        self.width = width or 1280
+
+        self.num_channels_latents = num_channels_latents
+        self.vae_scale_factor_temporal = vae_scale_factor_temporal
+        self.vae_scale_factor_spatial = vae_scale_factor_spatial
+        self.use_slicing = use_slicing or False
+
+    @property
+    def image_size(self):
+        return (self.height, self.width)
```
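A minimal sketch of the new VAE config, using only the constructor and defaults documented in the file above; note how batch sizes above 1 are clamped back to 1 so the compiled decoder can rely on batch slicing.

```python
from optimum.rbln import RBLNAutoencoderKLCosmosConfig

# Decoder-only VAE config with the documented defaults (121 frames at 704x1280).
vae_config = RBLNAutoencoderKLCosmosConfig(uses_encoder=False)
print(vae_config.num_frames, vae_config.image_size)  # 121 (704, 1280)

# Per the comments in __init__, batch_size > 1 logs a warning and is forced back to 1.
clamped = RBLNAutoencoderKLCosmosConfig(batch_size=4)
print(clamped.batch_size)  # 1
```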
optimum/rbln/diffusers/configurations/models/configuration_controlnet.py
CHANGED
```diff
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNControlNetModelConfig(RBLNModelConfig):
+    """Configuration class for RBLN ControlNet models."""
+
     subclass_non_save_attributes = ["_batch_size_is_specified"]
 
     def __init__(
@@ -27,7 +29,7 @@ class RBLNControlNetModelConfig(RBLNModelConfig):
         unet_sample_size: Optional[Tuple[int, int]] = None,
         vae_sample_size: Optional[Tuple[int, int]] = None,
         text_model_hidden_size: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```
optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py
CHANGED
```diff
@@ -12,12 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNPriorTransformerConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLN Prior Transformer models.
+
+    This class inherits from RBLNModelConfig and provides specific configuration options
+    for Prior Transformer models used in diffusion models like Kandinsky V2.2.
+    """
+
     subclass_non_save_attributes = ["_batch_size_is_specified"]
 
     def __init__(
@@ -25,7 +32,7 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         embedding_dim: Optional[int] = None,
         num_embeddings: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```
optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py
ADDED
```diff
@@ -0,0 +1,70 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
+    """Configuration class for RBLN Cosmos Transformer models."""
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        fps: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        embedding_dim: Optional[int] = None,
+        num_channels_latents: Optional[int] = None,
+        num_latent_frames: Optional[int] = None,
+        latent_height: Optional[int] = None,
+        latent_width: Optional[int] = None,
+        **kwargs: Dict[str, Any],
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
+            height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
+            width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
+            fps (Optional[int]): The frames per second of the generated video. Defaults to 30.
+            max_seq_len (Optional[int]): Maximum sequence length of prompt embeds.
+            embedding_dim (Optional[int]): Embedding vector dimension of prompt embeds.
+            num_channels_latents (Optional[int]): The number of channels in latent space.
+            latent_height (Optional[int]): The height in pixels in latent space.
+            latent_width (Optional[int]): The width in pixels in latent space.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        self.num_frames = num_frames or 121
+        self.height = height or 704
+        self.width = width or 1280
+        self.fps = fps or 30
+
+        self.max_seq_len = max_seq_len
+        self.num_channels_latents = num_channels_latents
+        self.num_latent_frames = num_latent_frames
+        self.latent_height = latent_height
+        self.latent_width = latent_width
+        self.embedding_dim = embedding_dim
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
```
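Likewise, a short sketch of the new transformer config; the prompt-embedding values (`max_seq_len`, `embedding_dim`) are model-dependent placeholders, not values taken from this diff.

```python
from optimum.rbln import RBLNCosmosTransformer3DModelConfig

transformer_config = RBLNCosmosTransformer3DModelConfig(
    batch_size=1,
    num_frames=121,      # documented defaults: 121 frames, 704x1280, 30 fps
    height=704,
    width=1280,
    fps=30,
    max_seq_len=512,     # placeholder: depends on the text encoder actually used
    embedding_dim=1024,  # placeholder: depends on the checkpoint
)

# A non-positive or non-integer batch_size raises ValueError in __init__.
print(transformer_config.fps, transformer_config.max_seq_len)  # 30 512
```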
optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py
CHANGED
```diff
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
+    """Configuration class for RBLN Stable Diffusion 3 Transformer models."""
+
     subclass_non_save_attributes = ["_batch_size_is_specified"]
 
     def __init__(
@@ -25,7 +27,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         prompt_embed_length: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```
optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py
CHANGED
```diff
@@ -12,12 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLN UNet2DCondition models.
+
+    This class inherits from RBLNModelConfig and provides specific configuration options
+    for UNet2DCondition models used in diffusion-based image generation.
+    """
+
     subclass_non_save_attributes = ["_batch_size_is_specified"]
 
     def __init__(
@@ -31,7 +38,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
         in_features: Optional[int] = None,
         text_model_hidden_size: Optional[int] = None,
         image_model_hidden_size: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```