optimum-rbln 0.8.1a4__py3-none-any.whl → 0.8.1a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +22 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
- optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -0
- optimum/rbln/diffusers/modeling_diffusers.py +41 -22
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +209 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +395 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +8 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
- optimum/rbln/utils/runtime_utils.py +3 -0
- {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/METADATA +4 -4
- {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/RECORD +37 -23
- {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -70,6 +70,8 @@ _import_structure = {
|
|
70
70
|
"RBLNCLIPVisionModelConfig",
|
71
71
|
"RBLNCLIPVisionModelWithProjection",
|
72
72
|
"RBLNCLIPVisionModelWithProjectionConfig",
|
73
|
+
"RBLNColPaliForRetrieval",
|
74
|
+
"RBLNColPaliForRetrievalConfig",
|
73
75
|
"RBLNDecoderOnlyModelForCausalLM",
|
74
76
|
"RBLNDecoderOnlyModelForCausalLMConfig",
|
75
77
|
"RBLNDistilBertForQuestionAnswering",
|
@@ -136,8 +138,17 @@ _import_structure = {
|
|
136
138
|
"diffusers": [
|
137
139
|
"RBLNAutoencoderKL",
|
138
140
|
"RBLNAutoencoderKLConfig",
|
141
|
+
"RBLNAutoencoderKLCosmos",
|
142
|
+
"RBLNAutoencoderKLCosmosConfig",
|
139
143
|
"RBLNControlNetModel",
|
140
144
|
"RBLNControlNetModelConfig",
|
145
|
+
"RBLNCosmosTextToWorldPipeline",
|
146
|
+
"RBLNCosmosVideoToWorldPipeline",
|
147
|
+
"RBLNCosmosTextToWorldPipelineConfig",
|
148
|
+
"RBLNCosmosVideoToWorldPipelineConfig",
|
149
|
+
"RBLNCosmosSafetyChecker",
|
150
|
+
"RBLNCosmosTransformer3DModel",
|
151
|
+
"RBLNCosmosTransformer3DModelConfig",
|
141
152
|
"RBLNDiffusionMixin",
|
142
153
|
"RBLNKandinskyV22CombinedPipeline",
|
143
154
|
"RBLNKandinskyV22CombinedPipelineConfig",
|
@@ -200,8 +211,17 @@ if TYPE_CHECKING:
|
|
200
211
|
from .diffusers import (
|
201
212
|
RBLNAutoencoderKL,
|
202
213
|
RBLNAutoencoderKLConfig,
|
214
|
+
RBLNAutoencoderKLCosmos,
|
215
|
+
RBLNAutoencoderKLCosmosConfig,
|
203
216
|
RBLNControlNetModel,
|
204
217
|
RBLNControlNetModelConfig,
|
218
|
+
RBLNCosmosSafetyChecker,
|
219
|
+
RBLNCosmosTextToWorldPipeline,
|
220
|
+
RBLNCosmosTextToWorldPipelineConfig,
|
221
|
+
RBLNCosmosTransformer3DModel,
|
222
|
+
RBLNCosmosTransformer3DModelConfig,
|
223
|
+
RBLNCosmosVideoToWorldPipeline,
|
224
|
+
RBLNCosmosVideoToWorldPipelineConfig,
|
205
225
|
RBLNDiffusionMixin,
|
206
226
|
RBLNKandinskyV22CombinedPipeline,
|
207
227
|
RBLNKandinskyV22CombinedPipelineConfig,
|
@@ -297,6 +317,8 @@ if TYPE_CHECKING:
|
|
297
317
|
RBLNCLIPVisionModelConfig,
|
298
318
|
RBLNCLIPVisionModelWithProjection,
|
299
319
|
RBLNCLIPVisionModelWithProjectionConfig,
|
320
|
+
RBLNColPaliForRetrieval,
|
321
|
+
RBLNColPaliForRetrievalConfig,
|
300
322
|
RBLNDecoderOnlyModelForCausalLM,
|
301
323
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
302
324
|
RBLNDistilBertForQuestionAnswering,
|
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
18
18
|
version_tuple: VERSION_TUPLE
|
19
19
|
|
20
|
-
__version__ = version = '0.8.1a4'
|
21
|
-
__version_tuple__ = version_tuple = (0, 8, 1, 'a4')
|
20
|
+
__version__ = version = '0.8.1a6'
|
21
|
+
__version_tuple__ = version_tuple = (0, 8, 1, 'a6')
|
@@ -18,14 +18,21 @@ from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_
|
|
18
18
|
from transformers.utils import _LazyModule
|
19
19
|
|
20
20
|
|
21
|
-
LOADABLE_CLASSES["optimum.rbln"] = {"RBLNBaseModel": ["save_pretrained", "from_pretrained"]}
|
21
|
+
LOADABLE_CLASSES["optimum.rbln"] = {
|
22
|
+
"RBLNBaseModel": ["save_pretrained", "from_pretrained"],
|
23
|
+
"RBLNCosmosSafetyChecker": ["save_pretrained", "from_pretrained"],
|
24
|
+
}
|
22
25
|
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
|
23
26
|
|
24
27
|
|
25
28
|
_import_structure = {
|
26
29
|
"configurations": [
|
27
30
|
"RBLNAutoencoderKLConfig",
|
31
|
+
"RBLNAutoencoderKLCosmosConfig",
|
28
32
|
"RBLNControlNetModelConfig",
|
33
|
+
"RBLNCosmosTextToWorldPipelineConfig",
|
34
|
+
"RBLNCosmosVideoToWorldPipelineConfig",
|
35
|
+
"RBLNCosmosTransformer3DModelConfig",
|
29
36
|
"RBLNKandinskyV22CombinedPipelineConfig",
|
30
37
|
"RBLNKandinskyV22Img2ImgCombinedPipelineConfig",
|
31
38
|
"RBLNKandinskyV22Img2ImgPipelineConfig",
|
@@ -52,6 +59,9 @@ _import_structure = {
|
|
52
59
|
"RBLNVQModelConfig",
|
53
60
|
],
|
54
61
|
"pipelines": [
|
62
|
+
"RBLNCosmosTextToWorldPipeline",
|
63
|
+
"RBLNCosmosVideoToWorldPipeline",
|
64
|
+
"RBLNCosmosSafetyChecker",
|
55
65
|
"RBLNKandinskyV22CombinedPipeline",
|
56
66
|
"RBLNKandinskyV22Img2ImgCombinedPipeline",
|
57
67
|
"RBLNKandinskyV22InpaintCombinedPipeline",
|
@@ -76,8 +86,10 @@ _import_structure = {
|
|
76
86
|
],
|
77
87
|
"models": [
|
78
88
|
"RBLNAutoencoderKL",
|
89
|
+
"RBLNAutoencoderKLCosmos",
|
79
90
|
"RBLNUNet2DConditionModel",
|
80
91
|
"RBLNControlNetModel",
|
92
|
+
"RBLNCosmosTransformer3DModel",
|
81
93
|
"RBLNSD3Transformer2DModel",
|
82
94
|
"RBLNPriorTransformer",
|
83
95
|
"RBLNVQModel",
|
@@ -90,7 +102,11 @@ _import_structure = {
|
|
90
102
|
if TYPE_CHECKING:
|
91
103
|
from .configurations import (
|
92
104
|
RBLNAutoencoderKLConfig,
|
105
|
+
RBLNAutoencoderKLCosmosConfig,
|
93
106
|
RBLNControlNetModelConfig,
|
107
|
+
RBLNCosmosTextToWorldPipelineConfig,
|
108
|
+
RBLNCosmosTransformer3DModelConfig,
|
109
|
+
RBLNCosmosVideoToWorldPipelineConfig,
|
94
110
|
RBLNKandinskyV22CombinedPipelineConfig,
|
95
111
|
RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
|
96
112
|
RBLNKandinskyV22Img2ImgPipelineConfig,
|
@@ -120,12 +136,16 @@ if TYPE_CHECKING:
|
|
120
136
|
from .models import (
|
121
137
|
RBLNAutoencoderKL,
|
122
138
|
RBLNControlNetModel,
|
139
|
+
RBLNCosmosTransformer3DModel,
|
123
140
|
RBLNPriorTransformer,
|
124
141
|
RBLNSD3Transformer2DModel,
|
125
142
|
RBLNUNet2DConditionModel,
|
126
143
|
RBLNVQModel,
|
127
144
|
)
|
128
145
|
from .pipelines import (
|
146
|
+
RBLNCosmosSafetyChecker,
|
147
|
+
RBLNCosmosTextToWorldPipeline,
|
148
|
+
RBLNCosmosVideoToWorldPipeline,
|
129
149
|
RBLNKandinskyV22CombinedPipeline,
|
130
150
|
RBLNKandinskyV22Img2ImgCombinedPipeline,
|
131
151
|
RBLNKandinskyV22Img2ImgPipeline,
|
@@ -1,12 +1,16 @@
|
|
1
1
|
from .models import (
|
2
2
|
RBLNAutoencoderKLConfig,
|
3
|
+
RBLNAutoencoderKLCosmosConfig,
|
3
4
|
RBLNControlNetModelConfig,
|
5
|
+
RBLNCosmosTransformer3DModelConfig,
|
4
6
|
RBLNPriorTransformerConfig,
|
5
7
|
RBLNSD3Transformer2DModelConfig,
|
6
8
|
RBLNUNet2DConditionModelConfig,
|
7
9
|
RBLNVQModelConfig,
|
8
10
|
)
|
9
11
|
from .pipelines import (
|
12
|
+
RBLNCosmosTextToWorldPipelineConfig,
|
13
|
+
RBLNCosmosVideoToWorldPipelineConfig,
|
10
14
|
RBLNKandinskyV22CombinedPipelineConfig,
|
11
15
|
RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
|
12
16
|
RBLNKandinskyV22Img2ImgPipelineConfig,
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
|
2
|
+
from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
|
2
3
|
from .configuration_controlnet import RBLNControlNetModelConfig
|
4
|
+
from .configuration_cosmos_transformer import RBLNCosmosTransformer3DModelConfig
|
3
5
|
from .configuration_prior_transformer import RBLNPriorTransformerConfig
|
4
6
|
from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
|
5
7
|
from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
from ....utils.logging import get_logger
|
19
|
+
|
20
|
+
|
21
|
+
logger = get_logger(__name__)
|
22
|
+
|
23
|
+
|
24
|
+
class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
|
25
|
+
def __init__(
|
26
|
+
self,
|
27
|
+
batch_size: Optional[int] = None,
|
28
|
+
uses_encoder: Optional[bool] = None,
|
29
|
+
num_frames: Optional[int] = None,
|
30
|
+
height: Optional[int] = None,
|
31
|
+
width: Optional[int] = None,
|
32
|
+
num_channels_latents: Optional[int] = None,
|
33
|
+
vae_scale_factor_temporal: Optional[int] = None,
|
34
|
+
vae_scale_factor_spatial: Optional[int] = None,
|
35
|
+
use_slicing: Optional[bool] = None,
|
36
|
+
**kwargs,
|
37
|
+
):
|
38
|
+
"""
|
39
|
+
Args:
|
40
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
41
|
+
uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
|
42
|
+
When False, only the decoder is used (for latent-to-video conversion).
|
43
|
+
num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
|
44
|
+
height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
|
45
|
+
width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
|
46
|
+
num_channels_latents (Optional[int]): The number of channels in latent space.
|
47
|
+
vae_scale_factor_temporal (Optional[int]): The scaling factor between time space and latent space.
|
48
|
+
Determines how much shorter the latent representations are compared to the original videos.
|
49
|
+
vae_scale_factor_spatial (Optional[int]): The scaling factor between pixel space and latent space.
|
50
|
+
Determines how much smaller the latent representations are compared to the original videos.
|
51
|
+
use_slicing (Optional[Bool]): Enable sliced VAE encoding and decoding.
|
52
|
+
If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
|
53
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
54
|
+
|
55
|
+
Raises:
|
56
|
+
ValueError: If batch_size is not a positive integer.
|
57
|
+
"""
|
58
|
+
super().__init__(**kwargs)
|
59
|
+
# Since the Cosmos VAE Decoder already requires approximately 7.9 GiB of memory,
|
60
|
+
# Optimum-rbln cannot execute this model on RBLN-CA12 when the batch size > 1.
|
61
|
+
# However, the Cosmos VAE Decoder propose batch slicing when the batch size is greater than 1,
|
62
|
+
# Optimum-rbln utilize this method by compiling with batch_size=1 to enable batch slicing.
|
63
|
+
self.batch_size = batch_size or 1
|
64
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
65
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
66
|
+
elif self.batch_size > 1:
|
67
|
+
logger.warning("The batch size of Cosmos VAE Decoder will be explicitly 1 for memory efficiency.")
|
68
|
+
self.batch_size = 1
|
69
|
+
|
70
|
+
self.uses_encoder = uses_encoder
|
71
|
+
self.num_frames = num_frames or 121
|
72
|
+
self.height = height or 704
|
73
|
+
self.width = width or 1280
|
74
|
+
|
75
|
+
self.num_channels_latents = num_channels_latents
|
76
|
+
self.vae_scale_factor_temporal = vae_scale_factor_temporal
|
77
|
+
self.vae_scale_factor_spatial = vae_scale_factor_spatial
|
78
|
+
self.use_slicing = use_slicing or False
|
79
|
+
|
80
|
+
@property
|
81
|
+
def image_size(self):
|
82
|
+
return (self.height, self.width)
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
|
19
|
+
|
20
|
+
class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
|
21
|
+
def __init__(
|
22
|
+
self,
|
23
|
+
batch_size: Optional[int] = None,
|
24
|
+
num_frames: Optional[int] = None,
|
25
|
+
height: Optional[int] = None,
|
26
|
+
width: Optional[int] = None,
|
27
|
+
fps: Optional[int] = None,
|
28
|
+
max_seq_len: Optional[int] = None,
|
29
|
+
embedding_dim: Optional[int] = None,
|
30
|
+
num_channels_latents: Optional[int] = None,
|
31
|
+
num_latent_frames: Optional[int] = None,
|
32
|
+
latent_height: Optional[int] = None,
|
33
|
+
latent_width: Optional[int] = None,
|
34
|
+
**kwargs,
|
35
|
+
):
|
36
|
+
"""
|
37
|
+
Args:
|
38
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
39
|
+
num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
|
40
|
+
height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
|
41
|
+
width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
|
42
|
+
fps (Optional[int]): The frames per second of the generated video. Defaults to 30.
|
43
|
+
max_seq_len (Optional[int]): Maximum sequence length of prompt embeds.
|
44
|
+
embedding_dim (Optional[int]): Embedding vector dimension of prompt embeds.
|
45
|
+
num_channels_latents (Optional[int]): The number of channels in latent space.
|
46
|
+
latent_height (Optional[int]): The height in pixels in latent space.
|
47
|
+
latent_width (Optional[int]): The width in pixels in latent space.
|
48
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
49
|
+
|
50
|
+
Raises:
|
51
|
+
ValueError: If batch_size is not a positive integer.
|
52
|
+
"""
|
53
|
+
super().__init__(**kwargs)
|
54
|
+
self.batch_size = batch_size or 1
|
55
|
+
self.num_frames = num_frames or 121
|
56
|
+
self.height = height or 704
|
57
|
+
self.width = width or 1280
|
58
|
+
self.fps = fps or 30
|
59
|
+
|
60
|
+
self.max_seq_len = max_seq_len
|
61
|
+
self.num_channels_latents = num_channels_latents
|
62
|
+
self.num_latent_frames = num_latent_frames
|
63
|
+
self.latent_height = latent_height
|
64
|
+
self.latent_width = latent_width
|
65
|
+
self.embedding_dim = embedding_dim
|
66
|
+
|
67
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
68
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
@@ -4,6 +4,7 @@ from .configuration_controlnet import (
|
|
4
4
|
RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig,
|
5
5
|
RBLNStableDiffusionXLControlNetPipelineConfig,
|
6
6
|
)
|
7
|
+
from .configuration_cosmos import RBLNCosmosTextToWorldPipelineConfig, RBLNCosmosVideoToWorldPipelineConfig
|
7
8
|
from .configuration_kandinsky2_2 import (
|
8
9
|
RBLNKandinskyV22CombinedPipelineConfig,
|
9
10
|
RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
|
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
from ....transformers import RBLNT5EncoderModelConfig
|
19
|
+
from ....utils.logging import get_logger
|
20
|
+
from ...pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
|
21
|
+
from ..models import RBLNAutoencoderKLCosmosConfig, RBLNCosmosTransformer3DModelConfig
|
22
|
+
|
23
|
+
|
24
|
+
logger = get_logger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class _RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
|
28
|
+
submodules = ["text_encoder", "transformer", "vae", "safety_checker"]
|
29
|
+
_vae_uses_encoder = False
|
30
|
+
|
31
|
+
def __init__(
|
32
|
+
self,
|
33
|
+
text_encoder: Optional[RBLNT5EncoderModelConfig] = None,
|
34
|
+
transformer: Optional[RBLNCosmosTransformer3DModelConfig] = None,
|
35
|
+
vae: Optional[RBLNAutoencoderKLCosmosConfig] = None,
|
36
|
+
safety_checker: Optional[RBLNCosmosSafetyCheckerConfig] = None,
|
37
|
+
*,
|
38
|
+
batch_size: Optional[int] = None,
|
39
|
+
height: Optional[int] = None,
|
40
|
+
width: Optional[int] = None,
|
41
|
+
num_frames: Optional[int] = None,
|
42
|
+
fps: Optional[int] = None,
|
43
|
+
max_seq_len: Optional[int] = None,
|
44
|
+
**kwargs,
|
45
|
+
):
|
46
|
+
"""
|
47
|
+
Args:
|
48
|
+
text_encoder (Optional[RBLNT5EncoderModelConfig]): Configuration for the text encoder component.
|
49
|
+
Initialized as RBLNT5EncoderModelConfig if not provided.
|
50
|
+
transformer (Optional[RBLNCosmosTransformer3DModelConfig]): Configuration for the UNet model component.
|
51
|
+
Initialized as RBLNCosmosTransformer3DModelConfig if not provided.
|
52
|
+
vae (Optional[RBLNAutoencoderKLCosmosConfig]): Configuration for the VAE model component.
|
53
|
+
Initialized as RBLNAutoencoderKLCosmosConfig if not provided.
|
54
|
+
safety_checker (Optional[RBLNCosmosSafetyCheckerConfig]): Configuration for the safety checker component.
|
55
|
+
Initialized as RBLNCosmosSafetyCheckerConfig if not provided.
|
56
|
+
batch_size (Optional[int]): Batch size for inference, applied to all submodules.
|
57
|
+
height (Optional[int]): Height of the generated videos.
|
58
|
+
width (Optional[int]): Width of the generated videos.
|
59
|
+
num_frames (Optional[int]): The number of frames in the generated video.
|
60
|
+
fps (Optional[int]): The frames per second of the generated video.
|
61
|
+
max_seq_len (Optional[int]): Maximum sequence length supported by the model.
|
62
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
63
|
+
"""
|
64
|
+
super().__init__(**kwargs)
|
65
|
+
|
66
|
+
self.text_encoder = self.init_submodule_config(
|
67
|
+
RBLNT5EncoderModelConfig, text_encoder, batch_size=batch_size, max_seq_len=max_seq_len
|
68
|
+
)
|
69
|
+
self.transformer = self.init_submodule_config(
|
70
|
+
RBLNCosmosTransformer3DModelConfig,
|
71
|
+
transformer,
|
72
|
+
batch_size=batch_size,
|
73
|
+
max_seq_len=max_seq_len,
|
74
|
+
height=height,
|
75
|
+
width=width,
|
76
|
+
num_frames=num_frames,
|
77
|
+
fps=fps,
|
78
|
+
)
|
79
|
+
self.vae = self.init_submodule_config(
|
80
|
+
RBLNAutoencoderKLCosmosConfig,
|
81
|
+
vae,
|
82
|
+
batch_size=batch_size,
|
83
|
+
uses_encoder=self.__class__._vae_uses_encoder,
|
84
|
+
height=height,
|
85
|
+
width=width,
|
86
|
+
num_frames=num_frames,
|
87
|
+
)
|
88
|
+
self.safety_checker = self.init_submodule_config(
|
89
|
+
RBLNCosmosSafetyCheckerConfig,
|
90
|
+
safety_checker,
|
91
|
+
batch_size=batch_size,
|
92
|
+
height=height,
|
93
|
+
width=width,
|
94
|
+
)
|
95
|
+
|
96
|
+
@property
|
97
|
+
def batch_size(self):
|
98
|
+
return self.vae.batch_size
|
99
|
+
|
100
|
+
@property
|
101
|
+
def max_seq_len(self):
|
102
|
+
return self.text_encoder.max_seq_len
|
103
|
+
|
104
|
+
|
105
|
+
class RBLNCosmosTextToWorldPipelineConfig(_RBLNCosmosPipelineBaseConfig):
|
106
|
+
_vae_uses_encoder = False
|
107
|
+
|
108
|
+
|
109
|
+
class RBLNCosmosVideoToWorldPipelineConfig(_RBLNCosmosPipelineBaseConfig):
|
110
|
+
_vae_uses_encoder = True
|
@@ -115,6 +115,7 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
|
|
115
115
|
text_encoder_3,
|
116
116
|
batch_size=batch_size,
|
117
117
|
max_seq_len=max_seq_len,
|
118
|
+
model_input_names=["input_ids"],
|
118
119
|
)
|
119
120
|
self.transformer = self.init_submodule_config(
|
120
121
|
RBLNSD3Transformer2DModelConfig,
|
@@ -45,7 +45,7 @@ class RBLNDiffusionMixin:
|
|
45
45
|
To use this mixin:
|
46
46
|
|
47
47
|
1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
|
48
|
-
2. Define the required _submodules class variable listing the components to be compiled.
|
48
|
+
2. Define the required _submodules and _optional_submodules class variable listing the components to be compiled.
|
49
49
|
|
50
50
|
Example:
|
51
51
|
```python
|
@@ -55,6 +55,7 @@ class RBLNDiffusionMixin:
|
|
55
55
|
|
56
56
|
Class Variables:
|
57
57
|
_submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
|
58
|
+
_optional_submodules: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])
|
58
59
|
|
59
60
|
Methods:
|
60
61
|
from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
|
@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:
|
|
67
68
|
|
68
69
|
_connected_classes = {}
|
69
70
|
_submodules = []
|
71
|
+
_optional_submodules = []
|
70
72
|
_prefix = {}
|
71
73
|
_rbln_config_class = None
|
72
74
|
_hf_class = None
|
@@ -184,31 +186,42 @@ class RBLNDiffusionMixin:
|
|
184
186
|
if export:
|
185
187
|
# keep submodules if user passed any of them.
|
186
188
|
passed_submodules = {
|
187
|
-
name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
|
189
|
+
name: kwargs.pop(name)
|
190
|
+
for name in cls._submodules + cls._optional_submodules
|
191
|
+
if isinstance(kwargs.get(name), RBLNModel)
|
188
192
|
}
|
189
193
|
|
190
194
|
else:
|
191
195
|
# raise error if any of submodules are torch module.
|
192
196
|
model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
|
193
|
-
for submodule_name in cls._submodules:
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
+
for submodule_name in cls._submodules + cls._optional_submodules:
|
198
|
+
passed_submodule = kwargs.get(submodule_name, None)
|
199
|
+
|
200
|
+
if passed_submodule is None:
|
201
|
+
module_name, class_name = model_index_config[submodule_name]
|
202
|
+
if module_name != "optimum.rbln":
|
203
|
+
raise ValueError(
|
204
|
+
f"Invalid module_name '{module_name}' found in model_index.json for "
|
205
|
+
f"submodule '{submodule_name}'. "
|
206
|
+
"Expected 'optimum.rbln'. Please check the model_index.json configuration."
|
207
|
+
"If you want to compile, set `export=True`."
|
208
|
+
)
|
209
|
+
|
210
|
+
submodule_cls = get_rbln_model_cls(class_name)
|
211
|
+
submodule_config = getattr(rbln_config, submodule_name)
|
212
|
+
submodule = submodule_cls.from_pretrained(
|
213
|
+
model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
|
197
214
|
)
|
198
215
|
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
216
|
+
else:
|
217
|
+
if passed_submodule.__class__.__name__.startswith("RBLN"):
|
218
|
+
submodule = passed_submodule
|
219
|
+
|
220
|
+
elif isinstance(passed_submodule, torch.nn.Module):
|
221
|
+
raise AssertionError(
|
222
|
+
f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
|
223
|
+
)
|
206
224
|
|
207
|
-
submodule_cls = get_rbln_model_cls(class_name)
|
208
|
-
submodule_config = getattr(rbln_config, submodule_name)
|
209
|
-
submodule = submodule_cls.from_pretrained(
|
210
|
-
model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
|
211
|
-
)
|
212
225
|
kwargs[submodule_name] = submodule
|
213
226
|
|
214
227
|
with ContextRblnConfig(
|
@@ -352,10 +365,16 @@ class RBLNDiffusionMixin:
|
|
352
365
|
# Causing warning messeages.
|
353
366
|
|
354
367
|
update_dict = {}
|
355
|
-
for submodule_name in cls._submodules:
|
368
|
+
for submodule_name in cls._submodules + cls._optional_submodules:
|
356
369
|
# replace submodule
|
357
|
-
|
358
|
-
|
370
|
+
if submodule_name in submodules:
|
371
|
+
setattr(model, submodule_name, submodules[submodule_name])
|
372
|
+
update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
|
373
|
+
else:
|
374
|
+
# It assumes that the modules in _optional_components is compiled
|
375
|
+
# and already registered as an attribute of the model.
|
376
|
+
update_dict[submodule_name] = ("optimum.rbln", getattr(model, submodule_name).__class__.__name__)
|
377
|
+
|
359
378
|
if cls._load_connected_pipes:
|
360
379
|
for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
|
361
380
|
prefix = cls._prefix.get(connected_pipe_name, "")
|
@@ -386,7 +405,7 @@ class RBLNDiffusionMixin:
|
|
386
405
|
return model
|
387
406
|
|
388
407
|
def get_compiled_image_size(self):
|
389
|
-
if hasattr(self, "vae"):
|
408
|
+
if hasattr(self, "vae") and hasattr(self.vae, "image_size"):
|
390
409
|
compiled_image_size = self.vae.image_size
|
391
410
|
else:
|
392
411
|
compiled_image_size = None
|
@@ -20,6 +20,7 @@ from transformers.utils import _LazyModule
|
|
20
20
|
_import_structure = {
|
21
21
|
"autoencoders": [
|
22
22
|
"RBLNAutoencoderKL",
|
23
|
+
"RBLNAutoencoderKLCosmos",
|
23
24
|
"RBLNVQModel",
|
24
25
|
],
|
25
26
|
"unets": [
|
@@ -28,6 +29,7 @@ _import_structure = {
|
|
28
29
|
"controlnet": ["RBLNControlNetModel"],
|
29
30
|
"transformers": [
|
30
31
|
"RBLNPriorTransformer",
|
32
|
+
"RBLNCosmosTransformer3DModel",
|
31
33
|
"RBLNSD3Transformer2DModel",
|
32
34
|
],
|
33
35
|
}
|
@@ -35,10 +37,12 @@ _import_structure = {
|
|
35
37
|
if TYPE_CHECKING:
|
36
38
|
from .autoencoders import (
|
37
39
|
RBLNAutoencoderKL,
|
40
|
+
RBLNAutoencoderKLCosmos,
|
38
41
|
RBLNVQModel,
|
39
42
|
)
|
40
43
|
from .controlnet import RBLNControlNetModel
|
41
44
|
from .transformers import (
|
45
|
+
RBLNCosmosTransformer3DModel,
|
42
46
|
RBLNPriorTransformer,
|
43
47
|
RBLNSD3Transformer2DModel,
|
44
48
|
)
|