optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -22,7 +22,7 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
|
|
|
22
22
|
Configuration class for RBLN Prior Transformer models.
|
|
23
23
|
|
|
24
24
|
This class inherits from RBLNModelConfig and provides specific configuration options
|
|
25
|
-
for
|
|
25
|
+
for Transformer models used in diffusion models like Kandinsky V2.2.
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
28
|
subclass_non_save_attributes = ["_batch_size_is_specified"]
|
|
@@ -32,14 +32,14 @@ class RBLNPriorTransformerConfig(RBLNModelConfig):
|
|
|
32
32
|
batch_size: Optional[int] = None,
|
|
33
33
|
embedding_dim: Optional[int] = None,
|
|
34
34
|
num_embeddings: Optional[int] = None,
|
|
35
|
-
**kwargs:
|
|
35
|
+
**kwargs: Any,
|
|
36
36
|
):
|
|
37
37
|
"""
|
|
38
38
|
Args:
|
|
39
39
|
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
40
40
|
embedding_dim (Optional[int]): Dimension of the embedding vectors in the model.
|
|
41
41
|
num_embeddings (Optional[int]): Number of discrete embeddings in the codebook.
|
|
42
|
-
|
|
42
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
43
43
|
|
|
44
44
|
Raises:
|
|
45
45
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -12,13 +12,18 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
|
|
21
|
-
"""
|
|
21
|
+
"""
|
|
22
|
+
Configuration class for RBLN Cosmos Transformer models.
|
|
23
|
+
|
|
24
|
+
This class inherits from RBLNModelConfig and provides specific configuration options
|
|
25
|
+
for Transformer models used in diffusion models like Cosmos.
|
|
26
|
+
"""
|
|
22
27
|
|
|
23
28
|
def __init__(
|
|
24
29
|
self,
|
|
@@ -33,7 +38,7 @@ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
|
|
|
33
38
|
num_latent_frames: Optional[int] = None,
|
|
34
39
|
latent_height: Optional[int] = None,
|
|
35
40
|
latent_width: Optional[int] = None,
|
|
36
|
-
**kwargs:
|
|
41
|
+
**kwargs: Any,
|
|
37
42
|
):
|
|
38
43
|
"""
|
|
39
44
|
Args:
|
|
@@ -47,11 +52,14 @@ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
|
|
|
47
52
|
num_channels_latents (Optional[int]): The number of channels in latent space.
|
|
48
53
|
latent_height (Optional[int]): The height in pixels in latent space.
|
|
49
54
|
latent_width (Optional[int]): The width in pixels in latent space.
|
|
50
|
-
|
|
55
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
51
56
|
|
|
52
57
|
Raises:
|
|
53
58
|
ValueError: If batch_size is not a positive integer.
|
|
54
59
|
"""
|
|
60
|
+
if kwargs.get("timeout") is None:
|
|
61
|
+
kwargs["timeout"] = 80
|
|
62
|
+
|
|
55
63
|
super().__init__(**kwargs)
|
|
56
64
|
self.batch_size = batch_size or 1
|
|
57
65
|
self.num_frames = num_frames or 121
|
|
@@ -12,13 +12,18 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
|
|
21
|
-
"""
|
|
21
|
+
"""
|
|
22
|
+
Configuration class for RBLN Stable Diffusion 3 Transformer models.
|
|
23
|
+
|
|
24
|
+
This class inherits from RBLNModelConfig and provides specific configuration options
|
|
25
|
+
for Transformer models used in diffusion models like Stable Diffusion 3.
|
|
26
|
+
"""
|
|
22
27
|
|
|
23
28
|
subclass_non_save_attributes = ["_batch_size_is_specified"]
|
|
24
29
|
|
|
@@ -27,7 +32,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
|
|
|
27
32
|
batch_size: Optional[int] = None,
|
|
28
33
|
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
29
34
|
prompt_embed_length: Optional[int] = None,
|
|
30
|
-
**kwargs:
|
|
35
|
+
**kwargs: Any,
|
|
31
36
|
):
|
|
32
37
|
"""
|
|
33
38
|
Args:
|
|
@@ -36,7 +41,7 @@ class RBLNSD3Transformer2DModelConfig(RBLNModelConfig):
|
|
|
36
41
|
of the generated samples. If an integer is provided, it's used for both height and width.
|
|
37
42
|
prompt_embed_length (Optional[int]): The length of the embedded prompt vectors that
|
|
38
43
|
will be used to condition the transformer model.
|
|
39
|
-
|
|
44
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
40
45
|
|
|
41
46
|
Raises:
|
|
42
47
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -38,7 +38,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
|
|
|
38
38
|
in_features: Optional[int] = None,
|
|
39
39
|
text_model_hidden_size: Optional[int] = None,
|
|
40
40
|
image_model_hidden_size: Optional[int] = None,
|
|
41
|
-
**kwargs:
|
|
41
|
+
**kwargs: Any,
|
|
42
42
|
):
|
|
43
43
|
"""
|
|
44
44
|
Args:
|
|
@@ -52,7 +52,7 @@ class RBLNUNet2DConditionModelConfig(RBLNModelConfig):
|
|
|
52
52
|
in_features (Optional[int]): Number of input features for the model.
|
|
53
53
|
text_model_hidden_size (Optional[int]): Hidden size of the text encoder model.
|
|
54
54
|
image_model_hidden_size (Optional[int]): Hidden size of the image encoder model.
|
|
55
|
-
|
|
55
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
56
56
|
|
|
57
57
|
Raises:
|
|
58
58
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNUNetSpatioTemporalConditionModelConfig(RBLNModelConfig):
|
|
21
|
+
subclass_non_save_attributes = ["_batch_size_is_specified"]
|
|
22
|
+
|
|
23
|
+
def __init__(
|
|
24
|
+
self,
|
|
25
|
+
batch_size: Optional[int] = None,
|
|
26
|
+
sample_size: Optional[Tuple[int, int]] = None,
|
|
27
|
+
in_features: Optional[int] = None,
|
|
28
|
+
num_frames: Optional[int] = None,
|
|
29
|
+
**kwargs: Any,
|
|
30
|
+
):
|
|
31
|
+
"""
|
|
32
|
+
Args:
|
|
33
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
34
|
+
sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the generated samples.
|
|
35
|
+
If an integer is provided, it's used for both height and width.
|
|
36
|
+
in_features (Optional[int]): Number of input features for the model.
|
|
37
|
+
num_frames (Optional[int]): The number of frames in the generated video.
|
|
38
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
39
|
+
|
|
40
|
+
Raises:
|
|
41
|
+
ValueError: If batch_size is not a positive integer.
|
|
42
|
+
"""
|
|
43
|
+
super().__init__(**kwargs)
|
|
44
|
+
self._batch_size_is_specified = batch_size is not None
|
|
45
|
+
|
|
46
|
+
self.batch_size = batch_size or 1
|
|
47
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
|
48
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
49
|
+
|
|
50
|
+
self.in_features = in_features
|
|
51
|
+
self.num_frames = num_frames
|
|
52
|
+
|
|
53
|
+
self.sample_size = sample_size
|
|
54
|
+
if isinstance(sample_size, int):
|
|
55
|
+
self.sample_size = (sample_size, sample_size)
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def batch_size_is_specified(self):
|
|
59
|
+
return self._batch_size_is_specified
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -33,7 +33,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
|
|
|
33
33
|
vqmodel_scale_factor: Optional[float] = None, # TODO: rename to scaling_factor
|
|
34
34
|
in_channels: Optional[int] = None,
|
|
35
35
|
latent_channels: Optional[int] = None,
|
|
36
|
-
**kwargs:
|
|
36
|
+
**kwargs: Any,
|
|
37
37
|
):
|
|
38
38
|
"""
|
|
39
39
|
Args:
|
|
@@ -46,7 +46,7 @@ class RBLNVQModelConfig(RBLNModelConfig):
|
|
|
46
46
|
Determines the downsampling ratio between original images and latent representations.
|
|
47
47
|
in_channels (Optional[int]): Number of input channels for the model.
|
|
48
48
|
latent_channels (Optional[int]): Number of channels in the latent space.
|
|
49
|
-
|
|
49
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
50
50
|
|
|
51
51
|
Raises:
|
|
52
52
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
from ....transformers import RBLNCLIPTextModelConfig, RBLNCLIPTextModelWithProjectionConfig
|
|
@@ -38,7 +38,7 @@ class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
|
|
|
38
38
|
sample_size: Optional[Tuple[int, int]] = None,
|
|
39
39
|
image_size: Optional[Tuple[int, int]] = None,
|
|
40
40
|
guidance_scale: Optional[float] = None,
|
|
41
|
-
**kwargs:
|
|
41
|
+
**kwargs: Any,
|
|
42
42
|
):
|
|
43
43
|
"""
|
|
44
44
|
Args:
|
|
@@ -59,7 +59,7 @@ class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
|
|
|
59
59
|
image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
|
|
60
60
|
Cannot be used together with img_height/img_width.
|
|
61
61
|
guidance_scale (Optional[float]): Scale for classifier-free guidance.
|
|
62
|
-
|
|
62
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
63
63
|
|
|
64
64
|
Raises:
|
|
65
65
|
ValueError: If both image_size and img_height/img_width are provided.
|
|
@@ -93,20 +93,27 @@ class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
|
|
|
93
93
|
elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
|
|
94
94
|
raise ValueError("Both img_height and img_width must be provided together if used")
|
|
95
95
|
|
|
96
|
-
self.text_encoder = self.
|
|
97
|
-
|
|
98
|
-
|
|
96
|
+
self.text_encoder = self.initialize_submodule_config(
|
|
97
|
+
text_encoder,
|
|
98
|
+
cls_name="RBLNCLIPTextModelConfig",
|
|
99
|
+
batch_size=batch_size,
|
|
100
|
+
)
|
|
101
|
+
self.unet = self.initialize_submodule_config(
|
|
99
102
|
unet,
|
|
103
|
+
cls_name="RBLNUNet2DConditionModelConfig",
|
|
100
104
|
sample_size=sample_size,
|
|
101
105
|
)
|
|
102
|
-
self.vae = self.
|
|
103
|
-
RBLNAutoencoderKLConfig,
|
|
106
|
+
self.vae = self.initialize_submodule_config(
|
|
104
107
|
vae,
|
|
108
|
+
cls_name="RBLNAutoencoderKLConfig",
|
|
105
109
|
batch_size=batch_size,
|
|
106
110
|
uses_encoder=self.__class__._vae_uses_encoder,
|
|
107
111
|
sample_size=image_size, # image size is equal to sample size in vae
|
|
108
112
|
)
|
|
109
|
-
self.controlnet = self.
|
|
113
|
+
self.controlnet = self.initialize_submodule_config(
|
|
114
|
+
controlnet,
|
|
115
|
+
cls_name="RBLNControlNetModelConfig",
|
|
116
|
+
)
|
|
110
117
|
|
|
111
118
|
# Get default guidance scale from original class to set UNet and ControlNet batch size
|
|
112
119
|
if guidance_scale is None:
|
|
@@ -178,7 +185,7 @@ class RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
|
|
|
178
185
|
sample_size: Optional[Tuple[int, int]] = None,
|
|
179
186
|
image_size: Optional[Tuple[int, int]] = None,
|
|
180
187
|
guidance_scale: Optional[float] = None,
|
|
181
|
-
**kwargs:
|
|
188
|
+
**kwargs: Any,
|
|
182
189
|
):
|
|
183
190
|
"""
|
|
184
191
|
Args:
|
|
@@ -201,7 +208,7 @@ class RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
|
|
|
201
208
|
image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
|
|
202
209
|
Cannot be used together with img_height/img_width.
|
|
203
210
|
guidance_scale (Optional[float]): Scale for classifier-free guidance.
|
|
204
|
-
|
|
211
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
205
212
|
|
|
206
213
|
Raises:
|
|
207
214
|
ValueError: If both image_size and img_height/img_width are provided.
|
|
@@ -235,23 +242,32 @@ class RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
|
|
|
235
242
|
elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
|
|
236
243
|
raise ValueError("Both img_height and img_width must be provided together if used")
|
|
237
244
|
|
|
238
|
-
self.text_encoder = self.
|
|
239
|
-
|
|
240
|
-
|
|
245
|
+
self.text_encoder = self.initialize_submodule_config(
|
|
246
|
+
text_encoder,
|
|
247
|
+
cls_name="RBLNCLIPTextModelConfig",
|
|
248
|
+
batch_size=batch_size,
|
|
241
249
|
)
|
|
242
|
-
self.
|
|
243
|
-
|
|
250
|
+
self.text_encoder_2 = self.initialize_submodule_config(
|
|
251
|
+
text_encoder_2,
|
|
252
|
+
cls_name="RBLNCLIPTextModelWithProjectionConfig",
|
|
253
|
+
batch_size=batch_size,
|
|
254
|
+
)
|
|
255
|
+
self.unet = self.initialize_submodule_config(
|
|
244
256
|
unet,
|
|
257
|
+
cls_name="RBLNUNet2DConditionModelConfig",
|
|
245
258
|
sample_size=sample_size,
|
|
246
259
|
)
|
|
247
|
-
self.vae = self.
|
|
248
|
-
RBLNAutoencoderKLConfig,
|
|
260
|
+
self.vae = self.initialize_submodule_config(
|
|
249
261
|
vae,
|
|
262
|
+
cls_name="RBLNAutoencoderKLConfig",
|
|
250
263
|
batch_size=batch_size,
|
|
251
264
|
uses_encoder=self.__class__._vae_uses_encoder,
|
|
252
265
|
sample_size=image_size, # image size is equal to sample size in vae
|
|
253
266
|
)
|
|
254
|
-
self.controlnet = self.
|
|
267
|
+
self.controlnet = self.initialize_submodule_config(
|
|
268
|
+
controlnet,
|
|
269
|
+
cls_name="RBLNControlNetModelConfig",
|
|
270
|
+
)
|
|
255
271
|
|
|
256
272
|
# Get default guidance scale from original class to set UNet and ControlNet batch size
|
|
257
273
|
guidance_scale = (
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
from ....transformers import RBLNT5EncoderModelConfig
|
|
@@ -41,7 +41,7 @@ class RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
|
|
|
41
41
|
num_frames: Optional[int] = None,
|
|
42
42
|
fps: Optional[int] = None,
|
|
43
43
|
max_seq_len: Optional[int] = None,
|
|
44
|
-
**kwargs:
|
|
44
|
+
**kwargs: Any,
|
|
45
45
|
):
|
|
46
46
|
"""
|
|
47
47
|
Args:
|
|
@@ -59,16 +59,19 @@ class RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
|
|
|
59
59
|
num_frames (Optional[int]): The number of frames in the generated video.
|
|
60
60
|
fps (Optional[int]): The frames per second of the generated video.
|
|
61
61
|
max_seq_len (Optional[int]): Maximum sequence length supported by the model.
|
|
62
|
-
|
|
62
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
63
63
|
"""
|
|
64
64
|
super().__init__(**kwargs)
|
|
65
65
|
|
|
66
|
-
self.text_encoder = self.
|
|
67
|
-
|
|
66
|
+
self.text_encoder = self.initialize_submodule_config(
|
|
67
|
+
text_encoder,
|
|
68
|
+
cls_name="RBLNT5EncoderModelConfig",
|
|
69
|
+
batch_size=batch_size,
|
|
70
|
+
max_seq_len=max_seq_len,
|
|
68
71
|
)
|
|
69
|
-
self.transformer = self.
|
|
70
|
-
RBLNCosmosTransformer3DModelConfig,
|
|
72
|
+
self.transformer = self.initialize_submodule_config(
|
|
71
73
|
transformer,
|
|
74
|
+
cls_name="RBLNCosmosTransformer3DModelConfig",
|
|
72
75
|
batch_size=batch_size,
|
|
73
76
|
max_seq_len=max_seq_len,
|
|
74
77
|
height=height,
|
|
@@ -76,18 +79,18 @@ class RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
|
|
|
76
79
|
num_frames=num_frames,
|
|
77
80
|
fps=fps,
|
|
78
81
|
)
|
|
79
|
-
self.vae = self.
|
|
80
|
-
RBLNAutoencoderKLCosmosConfig,
|
|
82
|
+
self.vae = self.initialize_submodule_config(
|
|
81
83
|
vae,
|
|
84
|
+
cls_name="RBLNAutoencoderKLCosmosConfig",
|
|
82
85
|
batch_size=batch_size,
|
|
83
86
|
uses_encoder=self.__class__._vae_uses_encoder,
|
|
84
87
|
height=height,
|
|
85
88
|
width=width,
|
|
86
89
|
num_frames=num_frames,
|
|
87
90
|
)
|
|
88
|
-
self.safety_checker = self.
|
|
89
|
-
RBLNCosmosSafetyCheckerConfig,
|
|
91
|
+
self.safety_checker = self.initialize_submodule_config(
|
|
90
92
|
safety_checker,
|
|
93
|
+
cls_name="RBLNCosmosSafetyCheckerConfig",
|
|
91
94
|
batch_size=batch_size,
|
|
92
95
|
height=height,
|
|
93
96
|
width=width,
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNCLIPVisionModelWithProjectionConfig
|
|
@@ -37,7 +37,7 @@ class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
|
|
|
37
37
|
img_width: Optional[int] = None,
|
|
38
38
|
height: Optional[int] = None,
|
|
39
39
|
width: Optional[int] = None,
|
|
40
|
-
**kwargs:
|
|
40
|
+
**kwargs: Any,
|
|
41
41
|
):
|
|
42
42
|
"""
|
|
43
43
|
Args:
|
|
@@ -54,7 +54,7 @@ class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
|
|
|
54
54
|
img_width (Optional[int]): Width of the generated images.
|
|
55
55
|
height (Optional[int]): Height of the generated images.
|
|
56
56
|
width (Optional[int]): Width of the generated images.
|
|
57
|
-
|
|
57
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
58
58
|
|
|
59
59
|
Raises:
|
|
60
60
|
ValueError: If both image_size and img_height/img_width are provided.
|
|
@@ -88,10 +88,14 @@ class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
|
|
|
88
88
|
elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
|
|
89
89
|
raise ValueError("Both img_height and img_width must be provided together if used")
|
|
90
90
|
|
|
91
|
-
self.unet = self.
|
|
92
|
-
|
|
93
|
-
|
|
91
|
+
self.unet = self.initialize_submodule_config(
|
|
92
|
+
unet,
|
|
93
|
+
cls_name="RBLNUNet2DConditionModelConfig",
|
|
94
|
+
sample_size=sample_size,
|
|
95
|
+
)
|
|
96
|
+
self.movq = self.initialize_submodule_config(
|
|
94
97
|
movq,
|
|
98
|
+
cls_name="RBLNVQModelConfig",
|
|
95
99
|
batch_size=batch_size,
|
|
96
100
|
sample_size=image_size, # image size is equal to sample size in vae
|
|
97
101
|
uses_encoder=self._movq_uses_encoder,
|
|
@@ -148,7 +152,7 @@ class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
|
|
|
148
152
|
*,
|
|
149
153
|
batch_size: Optional[int] = None,
|
|
150
154
|
guidance_scale: Optional[float] = None,
|
|
151
|
-
**kwargs:
|
|
155
|
+
**kwargs: Any,
|
|
152
156
|
):
|
|
153
157
|
"""
|
|
154
158
|
Initialize a configuration for Kandinsky 2.2 prior pipeline optimized for RBLN NPU.
|
|
@@ -166,21 +170,27 @@ class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
|
|
|
166
170
|
Initialized as RBLNPriorTransformerConfig if not provided.
|
|
167
171
|
batch_size (Optional[int]): Batch size for inference, applied to all submodules.
|
|
168
172
|
guidance_scale (Optional[float]): Scale for classifier-free guidance.
|
|
169
|
-
|
|
173
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
170
174
|
|
|
171
175
|
Note:
|
|
172
176
|
When guidance_scale > 1.0, the prior batch size is automatically doubled to
|
|
173
177
|
accommodate classifier-free guidance.
|
|
174
178
|
"""
|
|
175
179
|
super().__init__(**kwargs)
|
|
176
|
-
self.text_encoder = self.
|
|
177
|
-
|
|
180
|
+
self.text_encoder = self.initialize_submodule_config(
|
|
181
|
+
text_encoder,
|
|
182
|
+
cls_name="RBLNCLIPTextModelWithProjectionConfig",
|
|
183
|
+
batch_size=batch_size,
|
|
178
184
|
)
|
|
179
|
-
self.image_encoder = self.
|
|
180
|
-
|
|
185
|
+
self.image_encoder = self.initialize_submodule_config(
|
|
186
|
+
image_encoder,
|
|
187
|
+
cls_name="RBLNCLIPVisionModelWithProjectionConfig",
|
|
188
|
+
batch_size=batch_size,
|
|
189
|
+
)
|
|
190
|
+
self.prior = self.initialize_submodule_config(
|
|
191
|
+
prior,
|
|
192
|
+
cls_name="RBLNPriorTransformerConfig",
|
|
181
193
|
)
|
|
182
|
-
|
|
183
|
-
self.prior = self.init_submodule_config(RBLNPriorTransformerConfig, prior)
|
|
184
194
|
|
|
185
195
|
# Get default guidance scale from original class to set UNet batch size
|
|
186
196
|
if guidance_scale is None:
|
|
@@ -226,7 +236,7 @@ class RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
|
|
|
226
236
|
prior_text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
|
|
227
237
|
unet: Optional[RBLNUNet2DConditionModelConfig] = None,
|
|
228
238
|
movq: Optional[RBLNVQModelConfig] = None,
|
|
229
|
-
**kwargs:
|
|
239
|
+
**kwargs: Any,
|
|
230
240
|
):
|
|
231
241
|
"""
|
|
232
242
|
Initialize a configuration for combined Kandinsky 2.2 pipelines optimized for RBLN NPU.
|
|
@@ -259,7 +269,7 @@ class RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
|
|
|
259
269
|
Used if decoder_pipe is not provided.
|
|
260
270
|
movq (Optional[RBLNVQModelConfig]): Direct configuration for the MoVQ (VQ-GAN) model.
|
|
261
271
|
Used if decoder_pipe is not provided.
|
|
262
|
-
|
|
272
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
263
273
|
"""
|
|
264
274
|
super().__init__(**kwargs)
|
|
265
275
|
|
|
@@ -286,18 +296,18 @@ class RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
|
|
|
286
296
|
elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
|
|
287
297
|
raise ValueError("Both img_height and img_width must be provided together if used")
|
|
288
298
|
|
|
289
|
-
self.prior_pipe = self.
|
|
290
|
-
RBLNKandinskyV22PriorPipelineConfig,
|
|
299
|
+
self.prior_pipe = self.initialize_submodule_config(
|
|
291
300
|
prior_pipe,
|
|
301
|
+
cls_name="RBLNKandinskyV22PriorPipelineConfig",
|
|
292
302
|
prior=prior_prior,
|
|
293
303
|
image_encoder=prior_image_encoder,
|
|
294
304
|
text_encoder=prior_text_encoder,
|
|
295
305
|
batch_size=batch_size,
|
|
296
306
|
guidance_scale=guidance_scale,
|
|
297
307
|
)
|
|
298
|
-
self.decoder_pipe = self.
|
|
299
|
-
self._decoder_pipe_cls,
|
|
308
|
+
self.decoder_pipe = self.initialize_submodule_config(
|
|
300
309
|
decoder_pipe,
|
|
310
|
+
cls_name=self._decoder_pipe_cls.__name__,
|
|
301
311
|
unet=unet,
|
|
302
312
|
movq=movq,
|
|
303
313
|
batch_size=batch_size,
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
from ....transformers import RBLNCLIPTextModelConfig
|
|
@@ -37,7 +37,7 @@ class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
|
|
|
37
37
|
sample_size: Optional[Tuple[int, int]] = None,
|
|
38
38
|
image_size: Optional[Tuple[int, int]] = None,
|
|
39
39
|
guidance_scale: Optional[float] = None,
|
|
40
|
-
**kwargs:
|
|
40
|
+
**kwargs: Any,
|
|
41
41
|
):
|
|
42
42
|
"""
|
|
43
43
|
Args:
|
|
@@ -56,7 +56,7 @@ class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
|
|
|
56
56
|
image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
|
|
57
57
|
Cannot be used together with img_height/img_width.
|
|
58
58
|
guidance_scale (Optional[float]): Scale for classifier-free guidance.
|
|
59
|
-
|
|
59
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
60
60
|
|
|
61
61
|
Raises:
|
|
62
62
|
ValueError: If both image_size and img_height/img_width are provided.
|
|
@@ -90,18 +90,22 @@ class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
|
|
|
90
90
|
elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
|
|
91
91
|
raise ValueError("Both img_height and img_width must be provided together if used")
|
|
92
92
|
|
|
93
|
-
self.text_encoder = self.
|
|
94
|
-
|
|
95
|
-
|
|
93
|
+
self.text_encoder = self.initialize_submodule_config(
|
|
94
|
+
text_encoder,
|
|
95
|
+
cls_name="RBLNCLIPTextModelConfig",
|
|
96
|
+
batch_size=batch_size,
|
|
97
|
+
)
|
|
98
|
+
self.unet = self.initialize_submodule_config(
|
|
96
99
|
unet,
|
|
100
|
+
cls_name="RBLNUNet2DConditionModelConfig",
|
|
97
101
|
sample_size=sample_size,
|
|
98
102
|
)
|
|
99
|
-
self.vae = self.
|
|
100
|
-
RBLNAutoencoderKLConfig,
|
|
103
|
+
self.vae = self.initialize_submodule_config(
|
|
101
104
|
vae,
|
|
105
|
+
cls_name="RBLNAutoencoderKLConfig",
|
|
102
106
|
batch_size=batch_size,
|
|
103
107
|
uses_encoder=self.__class__._vae_uses_encoder,
|
|
104
|
-
sample_size=image_size,
|
|
108
|
+
sample_size=image_size,
|
|
105
109
|
)
|
|
106
110
|
|
|
107
111
|
# Get default guidance scale from original class to set UNet batch size
|