optimum-rbln 0.8.1a5__py3-none-any.whl → 0.8.1a7__py3-none-any.whl
This diff shows the contents of publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- optimum/rbln/__init__.py +18 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
- optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
- optimum/rbln/diffusers/modeling_diffusers.py +57 -40
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/modeling.py +38 -2
- optimum/rbln/modeling_base.py +18 -2
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
- optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
- optimum/rbln/utils/runtime_utils.py +49 -1
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED

@@ -138,8 +138,17 @@ _import_structure = {
     "diffusers": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLConfig",
+        "RBLNAutoencoderKLCosmos",
+        "RBLNAutoencoderKLCosmosConfig",
         "RBLNControlNetModel",
         "RBLNControlNetModelConfig",
+        "RBLNCosmosTextToWorldPipeline",
+        "RBLNCosmosVideoToWorldPipeline",
+        "RBLNCosmosTextToWorldPipelineConfig",
+        "RBLNCosmosVideoToWorldPipelineConfig",
+        "RBLNCosmosSafetyChecker",
+        "RBLNCosmosTransformer3DModel",
+        "RBLNCosmosTransformer3DModelConfig",
         "RBLNDiffusionMixin",
         "RBLNKandinskyV22CombinedPipeline",
         "RBLNKandinskyV22CombinedPipelineConfig",

@@ -202,8 +211,17 @@ if TYPE_CHECKING:
     from .diffusers import (
         RBLNAutoencoderKL,
         RBLNAutoencoderKLConfig,
+        RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLCosmosConfig,
         RBLNControlNetModel,
         RBLNControlNetModelConfig,
+        RBLNCosmosSafetyChecker,
+        RBLNCosmosTextToWorldPipeline,
+        RBLNCosmosTextToWorldPipelineConfig,
+        RBLNCosmosTransformer3DModel,
+        RBLNCosmosTransformer3DModelConfig,
+        RBLNCosmosVideoToWorldPipeline,
+        RBLNCosmosVideoToWorldPipelineConfig,
         RBLNDiffusionMixin,
         RBLNKandinskyV22CombinedPipeline,
         RBLNKandinskyV22CombinedPipelineConfig,
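With this release the Cosmos classes are exported from the package root. A minimal import sketch (assuming optimum-rbln 0.8.1a7 and its dependencies are installed):

```python
# Sketch: the newly exported Cosmos symbols, imported from the package root.
from optimum.rbln import (
    RBLNAutoencoderKLCosmos,
    RBLNCosmosSafetyChecker,
    RBLNCosmosTextToWorldPipeline,
    RBLNCosmosTextToWorldPipelineConfig,
    RBLNCosmosTransformer3DModel,
)
```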
optimum/rbln/__version__.py
CHANGED

@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.8.1a5'
-__version_tuple__ = version_tuple = (0, 8, 1, 'a5')
+__version__ = version = '0.8.1a7'
+__version_tuple__ = version_tuple = (0, 8, 1, 'a7')
optimum/rbln/diffusers/__init__.py
CHANGED

@@ -18,14 +18,21 @@ from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_CLASSES
 from transformers.utils import _LazyModule


-LOADABLE_CLASSES["optimum.rbln"] = {
+LOADABLE_CLASSES["optimum.rbln"] = {
+    "RBLNBaseModel": ["save_pretrained", "from_pretrained"],
+    "RBLNCosmosSafetyChecker": ["save_pretrained", "from_pretrained"],
+}
 ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])


 _import_structure = {
     "configurations": [
         "RBLNAutoencoderKLConfig",
+        "RBLNAutoencoderKLCosmosConfig",
         "RBLNControlNetModelConfig",
+        "RBLNCosmosTextToWorldPipelineConfig",
+        "RBLNCosmosVideoToWorldPipelineConfig",
+        "RBLNCosmosTransformer3DModelConfig",
         "RBLNKandinskyV22CombinedPipelineConfig",
         "RBLNKandinskyV22Img2ImgCombinedPipelineConfig",
         "RBLNKandinskyV22Img2ImgPipelineConfig",

@@ -52,6 +59,9 @@ _import_structure = {
         "RBLNVQModelConfig",
     ],
     "pipelines": [
+        "RBLNCosmosTextToWorldPipeline",
+        "RBLNCosmosVideoToWorldPipeline",
+        "RBLNCosmosSafetyChecker",
         "RBLNKandinskyV22CombinedPipeline",
         "RBLNKandinskyV22Img2ImgCombinedPipeline",
         "RBLNKandinskyV22InpaintCombinedPipeline",

@@ -76,8 +86,10 @@ _import_structure = {
     ],
     "models": [
         "RBLNAutoencoderKL",
+        "RBLNAutoencoderKLCosmos",
         "RBLNUNet2DConditionModel",
         "RBLNControlNetModel",
+        "RBLNCosmosTransformer3DModel",
         "RBLNSD3Transformer2DModel",
         "RBLNPriorTransformer",
         "RBLNVQModel",

@@ -90,7 +102,11 @@ _import_structure = {
 if TYPE_CHECKING:
     from .configurations import (
         RBLNAutoencoderKLConfig,
+        RBLNAutoencoderKLCosmosConfig,
         RBLNControlNetModelConfig,
+        RBLNCosmosTextToWorldPipelineConfig,
+        RBLNCosmosTransformer3DModelConfig,
+        RBLNCosmosVideoToWorldPipelineConfig,
         RBLNKandinskyV22CombinedPipelineConfig,
         RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
         RBLNKandinskyV22Img2ImgPipelineConfig,

@@ -120,12 +136,16 @@ if TYPE_CHECKING:
     from .models import (
         RBLNAutoencoderKL,
         RBLNControlNetModel,
+        RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNUNet2DConditionModel,
         RBLNVQModel,
     )
     from .pipelines import (
+        RBLNCosmosSafetyChecker,
+        RBLNCosmosTextToWorldPipeline,
+        RBLNCosmosVideoToWorldPipeline,
         RBLNKandinskyV22CombinedPipeline,
         RBLNKandinskyV22Img2ImgCombinedPipeline,
         RBLNKandinskyV22Img2ImgPipeline,
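The LOADABLE_CLASSES / ALL_IMPORTABLE_CLASSES registration above is what lets diffusers' pipeline save and load machinery dispatch `save_pretrained` / `from_pretrained` for RBLN components, now including the Cosmos safety checker. A small sketch of the observable effect (assuming an environment where both packages import cleanly):

```python
# Sketch: importing the subpackage runs the registration above at module import time.
import optimum.rbln.diffusers  # noqa: F401  (executes the LOADABLE_CLASSES update)

from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES

assert "RBLNBaseModel" in ALL_IMPORTABLE_CLASSES
assert "RBLNCosmosSafetyChecker" in ALL_IMPORTABLE_CLASSES
```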
optimum/rbln/diffusers/configurations/__init__.py
CHANGED

@@ -1,12 +1,16 @@
 from .models import (
     RBLNAutoencoderKLConfig,
+    RBLNAutoencoderKLCosmosConfig,
     RBLNControlNetModelConfig,
+    RBLNCosmosTransformer3DModelConfig,
     RBLNPriorTransformerConfig,
     RBLNSD3Transformer2DModelConfig,
     RBLNUNet2DConditionModelConfig,
     RBLNVQModelConfig,
 )
 from .pipelines import (
+    RBLNCosmosTextToWorldPipelineConfig,
+    RBLNCosmosVideoToWorldPipelineConfig,
     RBLNKandinskyV22CombinedPipelineConfig,
     RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
     RBLNKandinskyV22Img2ImgPipelineConfig,
optimum/rbln/diffusers/configurations/models/__init__.py
CHANGED

@@ -1,5 +1,7 @@
 from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
+from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
 from .configuration_controlnet import RBLNControlNetModelConfig
+from .configuration_cosmos_transformer import RBLNCosmosTransformer3DModelConfig
 from .configuration_prior_transformer import RBLNPriorTransformerConfig
 from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
 from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py
ADDED

@@ -0,0 +1,82 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        uses_encoder: Optional[bool] = None,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_channels_latents: Optional[int] = None,
+        vae_scale_factor_temporal: Optional[int] = None,
+        vae_scale_factor_spatial: Optional[int] = None,
+        use_slicing: Optional[bool] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
+                When False, only the decoder is used (for latent-to-video conversion).
+            num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
+            height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
+            width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
+            num_channels_latents (Optional[int]): The number of channels in latent space.
+            vae_scale_factor_temporal (Optional[int]): The scaling factor between time space and latent space.
+                Determines how much shorter the latent representations are compared to the original videos.
+            vae_scale_factor_spatial (Optional[int]): The scaling factor between pixel space and latent space.
+                Determines how much smaller the latent representations are compared to the original videos.
+            use_slicing (Optional[Bool]): Enable sliced VAE encoding and decoding.
+                If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        # Since the Cosmos VAE Decoder already requires approximately 7.9 GiB of memory,
+        # Optimum-rbln cannot execute this model on RBLN-CA12 when the batch size > 1.
+        # However, the Cosmos VAE Decoder propose batch slicing when the batch size is greater than 1,
+        # Optimum-rbln utilize this method by compiling with batch_size=1 to enable batch slicing.
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+        elif self.batch_size > 1:
+            logger.warning("The batch size of Cosmos VAE Decoder will be explicitly 1 for memory efficiency.")
+            self.batch_size = 1
+
+        self.uses_encoder = uses_encoder
+        self.num_frames = num_frames or 121
+        self.height = height or 704
+        self.width = width or 1280
+
+        self.num_channels_latents = num_channels_latents
+        self.vae_scale_factor_temporal = vae_scale_factor_temporal
+        self.vae_scale_factor_spatial = vae_scale_factor_spatial
+        self.use_slicing = use_slicing or False
+
+    @property
+    def image_size(self):
+        return (self.height, self.width)
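For reference, a short usage sketch of the new VAE config based only on the code above (it assumes `RBLNModelConfig` requires no additional mandatory arguments):

```python
# Sketch: defaults and the batch-size clamp of RBLNAutoencoderKLCosmosConfig.
from optimum.rbln import RBLNAutoencoderKLCosmosConfig

vae_config = RBLNAutoencoderKLCosmosConfig(batch_size=2, uses_encoder=False)
assert vae_config.batch_size == 1            # clamped back to 1 with a warning (memory limit)
assert vae_config.image_size == (704, 1280)  # (height, width) defaults
assert vae_config.num_frames == 121
assert vae_config.use_slicing is False
```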
optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py
ADDED

@@ -0,0 +1,68 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        fps: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        embedding_dim: Optional[int] = None,
+        num_channels_latents: Optional[int] = None,
+        num_latent_frames: Optional[int] = None,
+        latent_height: Optional[int] = None,
+        latent_width: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
+            height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
+            width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
+            fps (Optional[int]): The frames per second of the generated video. Defaults to 30.
+            max_seq_len (Optional[int]): Maximum sequence length of prompt embeds.
+            embedding_dim (Optional[int]): Embedding vector dimension of prompt embeds.
+            num_channels_latents (Optional[int]): The number of channels in latent space.
+            latent_height (Optional[int]): The height in pixels in latent space.
+            latent_width (Optional[int]): The width in pixels in latent space.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        self.num_frames = num_frames or 121
+        self.height = height or 704
+        self.width = width or 1280
+        self.fps = fps or 30
+
+        self.max_seq_len = max_seq_len
+        self.num_channels_latents = num_channels_latents
+        self.num_latent_frames = num_latent_frames
+        self.latent_height = latent_height
+        self.latent_width = latent_width
+        self.embedding_dim = embedding_dim
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
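Similarly for the transformer config, a sketch of its resolved defaults (again assuming no extra required parent arguments):

```python
# Sketch: RBLNCosmosTransformer3DModelConfig defaults resolved from the code above.
from optimum.rbln import RBLNCosmosTransformer3DModelConfig

transformer_config = RBLNCosmosTransformer3DModelConfig(max_seq_len=512)
assert (transformer_config.num_frames, transformer_config.fps) == (121, 30)
assert (transformer_config.height, transformer_config.width) == (704, 1280)
assert transformer_config.max_seq_len == 512
```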
optimum/rbln/diffusers/configurations/pipelines/__init__.py
CHANGED

@@ -4,6 +4,7 @@ from .configuration_controlnet import (
     RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig,
     RBLNStableDiffusionXLControlNetPipelineConfig,
 )
+from .configuration_cosmos import RBLNCosmosTextToWorldPipelineConfig, RBLNCosmosVideoToWorldPipelineConfig
 from .configuration_kandinsky2_2 import (
     RBLNKandinskyV22CombinedPipelineConfig,
     RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py
CHANGED

@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNControlNetModelConfig, RBLNUNet2DConditionModelConfig


 class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
-    """
-    Base configuration for Stable Diffusion ControlNet pipelines.
-    """
-
     submodules = ["text_encoder", "unet", "vae", "controlnet"]
     _vae_uses_encoder = False

optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py
ADDED

@@ -0,0 +1,110 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ....transformers import RBLNT5EncoderModelConfig
+from ....utils.logging import get_logger
+from ...pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
+from ..models import RBLNAutoencoderKLCosmosConfig, RBLNCosmosTransformer3DModelConfig
+
+
+logger = get_logger(__name__)
+
+
+class _RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
+    submodules = ["text_encoder", "transformer", "vae", "safety_checker"]
+    _vae_uses_encoder = False
+
+    def __init__(
+        self,
+        text_encoder: Optional[RBLNT5EncoderModelConfig] = None,
+        transformer: Optional[RBLNCosmosTransformer3DModelConfig] = None,
+        vae: Optional[RBLNAutoencoderKLCosmosConfig] = None,
+        safety_checker: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+        *,
+        batch_size: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        fps: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            text_encoder (Optional[RBLNT5EncoderModelConfig]): Configuration for the text encoder component.
+                Initialized as RBLNT5EncoderModelConfig if not provided.
+            transformer (Optional[RBLNCosmosTransformer3DModelConfig]): Configuration for the UNet model component.
+                Initialized as RBLNCosmosTransformer3DModelConfig if not provided.
+            vae (Optional[RBLNAutoencoderKLCosmosConfig]): Configuration for the VAE model component.
+                Initialized as RBLNAutoencoderKLCosmosConfig if not provided.
+            safety_checker (Optional[RBLNCosmosSafetyCheckerConfig]): Configuration for the safety checker component.
+                Initialized as RBLNCosmosSafetyCheckerConfig if not provided.
+            batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+            height (Optional[int]): Height of the generated videos.
+            width (Optional[int]): Width of the generated videos.
+            num_frames (Optional[int]): The number of frames in the generated video.
+            fps (Optional[int]): The frames per second of the generated video.
+            max_seq_len (Optional[int]): Maximum sequence length supported by the model.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        """
+        super().__init__(**kwargs)
+
+        self.text_encoder = self.init_submodule_config(
+            RBLNT5EncoderModelConfig, text_encoder, batch_size=batch_size, max_seq_len=max_seq_len
+        )
+        self.transformer = self.init_submodule_config(
+            RBLNCosmosTransformer3DModelConfig,
+            transformer,
+            batch_size=batch_size,
+            max_seq_len=max_seq_len,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            fps=fps,
+        )
+        self.vae = self.init_submodule_config(
+            RBLNAutoencoderKLCosmosConfig,
+            vae,
+            batch_size=batch_size,
+            uses_encoder=self.__class__._vae_uses_encoder,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+        )
+        self.safety_checker = self.init_submodule_config(
+            RBLNCosmosSafetyCheckerConfig,
+            safety_checker,
+            batch_size=batch_size,
+            height=height,
+            width=width,
+        )
+
+    @property
+    def batch_size(self):
+        return self.vae.batch_size
+
+    @property
+    def max_seq_len(self):
+        return self.text_encoder.max_seq_len
+
+
+class RBLNCosmosTextToWorldPipelineConfig(_RBLNCosmosPipelineBaseConfig):
+    _vae_uses_encoder = False
+
+
+class RBLNCosmosVideoToWorldPipelineConfig(_RBLNCosmosPipelineBaseConfig):
+    _vae_uses_encoder = True
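A hedged sketch of how the new pipeline configs fan shared settings out to their submodules, assuming `init_submodule_config` forwards the keyword arguments into the submodule config as the surrounding code suggests:

```python
# Sketch: shared settings propagate to submodule configs; Text2World compiles only the
# VAE decoder, while Video2World also compiles the VAE encoder.
from optimum.rbln import (
    RBLNCosmosTextToWorldPipelineConfig,
    RBLNCosmosVideoToWorldPipelineConfig,
)

t2w = RBLNCosmosTextToWorldPipelineConfig(max_seq_len=512, height=704, width=1280, num_frames=121)
assert t2w.max_seq_len == 512          # property delegates to the text_encoder config
assert t2w.vae.uses_encoder is False   # _vae_uses_encoder = False

v2w = RBLNCosmosVideoToWorldPipelineConfig(max_seq_len=512)
assert v2w.vae.uses_encoder is True    # _vae_uses_encoder = True
```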
optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py
CHANGED

@@ -21,8 +21,6 @@ from ..models.configuration_prior_transformer import RBLNPriorTransformerConfig


 class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
-    """Base configuration class for Kandinsky V2.2 decoder pipelines."""
-
     submodules = ["unet", "movq"]
     _movq_uses_encoder = False

optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py
CHANGED

@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig


 class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
-    """
-    Base configuration for Stable Diffusion pipelines.
-    """
-
     submodules = ["text_encoder", "unet", "vae"]
     _vae_uses_encoder = False

optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py
CHANGED

@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig


 class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
-    """
-    Base configuration for Stable Diffusion 3 pipelines.
-    """
-
     submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
     _vae_uses_encoder = False


@@ -115,6 +111,7 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
             text_encoder_3,
             batch_size=batch_size,
             max_seq_len=max_seq_len,
+            model_input_names=["input_ids"],
         )
         self.transformer = self.init_submodule_config(
             RBLNSD3Transformer2DModelConfig,
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py
CHANGED

@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig


 class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
-    """
-    Base configuration for Stable Diffusion XL pipelines.
-    """
-
     submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
     _vae_uses_encoder = False

optimum/rbln/diffusers/modeling_diffusers.py
CHANGED

@@ -45,7 +45,7 @@ class RBLNDiffusionMixin:
    To use this mixin:

    1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
-    2. Define the required _submodules class variable listing the components to be compiled.
+    2. Define the required _submodules and _optional_submodules class variable listing the components to be compiled.

    Example:
    ```python

@@ -55,6 +55,7 @@ class RBLNDiffusionMixin:

    Class Variables:
        _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
+        _optional_submodules: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])

    Methods:
        from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint

@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:

    _connected_classes = {}
    _submodules = []
+    _optional_submodules = []
    _prefix = {}
    _rbln_config_class = None
    _hf_class = None

@@ -184,31 +186,42 @@ class RBLNDiffusionMixin:
         if export:
             # keep submodules if user passed any of them.
             passed_submodules = {
-                name: kwargs.pop(name)
+                name: kwargs.pop(name)
+                for name in cls._submodules + cls._optional_submodules
+                if isinstance(kwargs.get(name), RBLNModel)
             }

         else:
             # raise error if any of submodules are torch module.
             model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
-            for submodule_name in cls._submodules:
-
-
-
+            for submodule_name in cls._submodules + cls._optional_submodules:
+                passed_submodule = kwargs.get(submodule_name, None)
+
+                if passed_submodule is None:
+                    module_name, class_name = model_index_config[submodule_name]
+                    if module_name != "optimum.rbln":
+                        raise ValueError(
+                            f"Invalid module_name '{module_name}' found in model_index.json for "
+                            f"submodule '{submodule_name}'. "
+                            "Expected 'optimum.rbln'. Please check the model_index.json configuration."
+                            "If you want to compile, set `export=True`."
+                        )
+
+                    submodule_cls = get_rbln_model_cls(class_name)
+                    submodule_config = getattr(rbln_config, submodule_name)
+                    submodule = submodule_cls.from_pretrained(
+                        model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
                     )

-
-
-
-
-
-
-
+                else:
+                    if passed_submodule.__class__.__name__.startswith("RBLN"):
+                        submodule = passed_submodule
+
+                    elif isinstance(passed_submodule, torch.nn.Module):
+                        raise AssertionError(
+                            f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
+                        )

-                submodule_cls = get_rbln_model_cls(class_name)
-                submodule_config = getattr(rbln_config, submodule_name)
-                submodule = submodule_cls.from_pretrained(
-                    model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
-                )
                 kwargs[submodule_name] = submodule

         with ContextRblnConfig(

@@ -352,10 +365,16 @@ class RBLNDiffusionMixin:
         # Causing warning messeages.

         update_dict = {}
-        for submodule_name in cls._submodules:
+        for submodule_name in cls._submodules + cls._optional_submodules:
             # replace submodule
-
-
+            if submodule_name in submodules:
+                setattr(model, submodule_name, submodules[submodule_name])
+                update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
+            else:
+                # It assumes that the modules in _optional_components is compiled
+                # and already registered as an attribute of the model.
+                update_dict[submodule_name] = ("optimum.rbln", getattr(model, submodule_name).__class__.__name__)
+
         if cls._load_connected_pipes:
             for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
                 prefix = cls._prefix.get(connected_pipe_name, "")

@@ -386,31 +405,29 @@ class RBLNDiffusionMixin:
         return model

    def get_compiled_image_size(self):
-        if hasattr(self, "vae"):
+        if hasattr(self, "vae") and hasattr(self.vae, "image_size"):
             compiled_image_size = self.vae.image_size
         else:
             compiled_image_size = None
         return compiled_image_size

    def handle_additional_kwargs(self, **kwargs):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ```
-        """
+        # Function to handle additional compile-time parameters during inference.
+
+        # If the additional variable is determined by another module, this method should be overrided.
+
+        # Example:
+        # ```python
+        # if hasattr(self, "movq"):
+        #     compiled_image_size = self.movq.image_size
+        #     kwargs["height"] = compiled_image_size[0]
+        #     kwargs["width"] = compiled_image_size[1]
+
+        # compiled_num_frames = self.unet.rbln_config.num_frames
+        # if compiled_num_frames is not None:
+        #     kwargs["num_frames"] = compiled_num_frames
+        # return kwargs
+        # ```
         return kwargs

    @remove_compile_time_kwargs
optimum/rbln/diffusers/models/__init__.py
CHANGED

@@ -20,6 +20,7 @@ from transformers.utils import _LazyModule
 _import_structure = {
     "autoencoders": [
         "RBLNAutoencoderKL",
+        "RBLNAutoencoderKLCosmos",
         "RBLNVQModel",
     ],
     "unets": [

@@ -28,6 +29,7 @@ _import_structure = {
     "controlnet": ["RBLNControlNetModel"],
     "transformers": [
         "RBLNPriorTransformer",
+        "RBLNCosmosTransformer3DModel",
         "RBLNSD3Transformer2DModel",
     ],
 }

@@ -35,10 +37,12 @@ _import_structure = {
 if TYPE_CHECKING:
     from .autoencoders import (
         RBLNAutoencoderKL,
+        RBLNAutoencoderKLCosmos,
         RBLNVQModel,
     )
     from .controlnet import RBLNControlNetModel
     from .transformers import (
+        RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
     )
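The same classes are also reachable through the models subpackage via the lazy-module table above; a one-line sketch (assuming the package is installed):

```python
# Sketch: subpackage-level imports exposed by the lazy module table above.
from optimum.rbln.diffusers.models import RBLNAutoencoderKLCosmos, RBLNCosmosTransformer3DModel
```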
|