optimum-rbln 0.8.1a5__py3-none-any.whl → 0.8.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. optimum/rbln/__init__.py +18 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +21 -1
  4. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  5. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
  8. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
  15. optimum/rbln/diffusers/modeling_diffusers.py +57 -40
  16. optimum/rbln/diffusers/models/__init__.py +4 -0
  17. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  18. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
  19. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  20. optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
  21. optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
  22. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  25. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  30. optimum/rbln/modeling.py +38 -2
  31. optimum/rbln/modeling_base.py +18 -2
  32. optimum/rbln/transformers/modeling_generic.py +3 -3
  33. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  34. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  35. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  36. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  37. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
  38. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
  39. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
  40. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
  41. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
  42. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
  43. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
  44. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  45. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  46. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  47. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
  48. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
  50. optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
  51. optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
  52. optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
  53. optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
  54. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
  55. optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
  56. optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
  57. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  58. optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
  59. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
  60. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
  61. optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
  62. optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
  63. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
  64. optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
  65. optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
  66. optimum/rbln/utils/runtime_utils.py +49 -1
  67. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
  68. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
  69. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
  70. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -138,8 +138,17 @@ _import_structure = {
138
138
  "diffusers": [
139
139
  "RBLNAutoencoderKL",
140
140
  "RBLNAutoencoderKLConfig",
141
+ "RBLNAutoencoderKLCosmos",
142
+ "RBLNAutoencoderKLCosmosConfig",
141
143
  "RBLNControlNetModel",
142
144
  "RBLNControlNetModelConfig",
145
+ "RBLNCosmosTextToWorldPipeline",
146
+ "RBLNCosmosVideoToWorldPipeline",
147
+ "RBLNCosmosTextToWorldPipelineConfig",
148
+ "RBLNCosmosVideoToWorldPipelineConfig",
149
+ "RBLNCosmosSafetyChecker",
150
+ "RBLNCosmosTransformer3DModel",
151
+ "RBLNCosmosTransformer3DModelConfig",
143
152
  "RBLNDiffusionMixin",
144
153
  "RBLNKandinskyV22CombinedPipeline",
145
154
  "RBLNKandinskyV22CombinedPipelineConfig",
@@ -202,8 +211,17 @@ if TYPE_CHECKING:
202
211
  from .diffusers import (
203
212
  RBLNAutoencoderKL,
204
213
  RBLNAutoencoderKLConfig,
214
+ RBLNAutoencoderKLCosmos,
215
+ RBLNAutoencoderKLCosmosConfig,
205
216
  RBLNControlNetModel,
206
217
  RBLNControlNetModelConfig,
218
+ RBLNCosmosSafetyChecker,
219
+ RBLNCosmosTextToWorldPipeline,
220
+ RBLNCosmosTextToWorldPipelineConfig,
221
+ RBLNCosmosTransformer3DModel,
222
+ RBLNCosmosTransformer3DModelConfig,
223
+ RBLNCosmosVideoToWorldPipeline,
224
+ RBLNCosmosVideoToWorldPipelineConfig,
207
225
  RBLNDiffusionMixin,
208
226
  RBLNKandinskyV22CombinedPipeline,
209
227
  RBLNKandinskyV22CombinedPipelineConfig,
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.8.1a5'
21
- __version_tuple__ = version_tuple = (0, 8, 1, 'a5')
20
+ __version__ = version = '0.8.1a7'
21
+ __version_tuple__ = version_tuple = (0, 8, 1, 'a7')
@@ -18,14 +18,21 @@ from diffusers.pipelines.pipeline_utils import ALL_IMPORTABLE_CLASSES, LOADABLE_
18
18
  from transformers.utils import _LazyModule
19
19
 
20
20
 
21
- LOADABLE_CLASSES["optimum.rbln"] = {"RBLNBaseModel": ["save_pretrained", "from_pretrained"]}
21
+ LOADABLE_CLASSES["optimum.rbln"] = {
22
+ "RBLNBaseModel": ["save_pretrained", "from_pretrained"],
23
+ "RBLNCosmosSafetyChecker": ["save_pretrained", "from_pretrained"],
24
+ }
22
25
  ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES["optimum.rbln"])
23
26
 
24
27
 
25
28
  _import_structure = {
26
29
  "configurations": [
27
30
  "RBLNAutoencoderKLConfig",
31
+ "RBLNAutoencoderKLCosmosConfig",
28
32
  "RBLNControlNetModelConfig",
33
+ "RBLNCosmosTextToWorldPipelineConfig",
34
+ "RBLNCosmosVideoToWorldPipelineConfig",
35
+ "RBLNCosmosTransformer3DModelConfig",
29
36
  "RBLNKandinskyV22CombinedPipelineConfig",
30
37
  "RBLNKandinskyV22Img2ImgCombinedPipelineConfig",
31
38
  "RBLNKandinskyV22Img2ImgPipelineConfig",
@@ -52,6 +59,9 @@ _import_structure = {
52
59
  "RBLNVQModelConfig",
53
60
  ],
54
61
  "pipelines": [
62
+ "RBLNCosmosTextToWorldPipeline",
63
+ "RBLNCosmosVideoToWorldPipeline",
64
+ "RBLNCosmosSafetyChecker",
55
65
  "RBLNKandinskyV22CombinedPipeline",
56
66
  "RBLNKandinskyV22Img2ImgCombinedPipeline",
57
67
  "RBLNKandinskyV22InpaintCombinedPipeline",
@@ -76,8 +86,10 @@ _import_structure = {
76
86
  ],
77
87
  "models": [
78
88
  "RBLNAutoencoderKL",
89
+ "RBLNAutoencoderKLCosmos",
79
90
  "RBLNUNet2DConditionModel",
80
91
  "RBLNControlNetModel",
92
+ "RBLNCosmosTransformer3DModel",
81
93
  "RBLNSD3Transformer2DModel",
82
94
  "RBLNPriorTransformer",
83
95
  "RBLNVQModel",
@@ -90,7 +102,11 @@ _import_structure = {
90
102
  if TYPE_CHECKING:
91
103
  from .configurations import (
92
104
  RBLNAutoencoderKLConfig,
105
+ RBLNAutoencoderKLCosmosConfig,
93
106
  RBLNControlNetModelConfig,
107
+ RBLNCosmosTextToWorldPipelineConfig,
108
+ RBLNCosmosTransformer3DModelConfig,
109
+ RBLNCosmosVideoToWorldPipelineConfig,
94
110
  RBLNKandinskyV22CombinedPipelineConfig,
95
111
  RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
96
112
  RBLNKandinskyV22Img2ImgPipelineConfig,
@@ -120,12 +136,16 @@ if TYPE_CHECKING:
120
136
  from .models import (
121
137
  RBLNAutoencoderKL,
122
138
  RBLNControlNetModel,
139
+ RBLNCosmosTransformer3DModel,
123
140
  RBLNPriorTransformer,
124
141
  RBLNSD3Transformer2DModel,
125
142
  RBLNUNet2DConditionModel,
126
143
  RBLNVQModel,
127
144
  )
128
145
  from .pipelines import (
146
+ RBLNCosmosSafetyChecker,
147
+ RBLNCosmosTextToWorldPipeline,
148
+ RBLNCosmosVideoToWorldPipeline,
129
149
  RBLNKandinskyV22CombinedPipeline,
130
150
  RBLNKandinskyV22Img2ImgCombinedPipeline,
131
151
  RBLNKandinskyV22Img2ImgPipeline,
@@ -1,12 +1,16 @@
1
1
  from .models import (
2
2
  RBLNAutoencoderKLConfig,
3
+ RBLNAutoencoderKLCosmosConfig,
3
4
  RBLNControlNetModelConfig,
5
+ RBLNCosmosTransformer3DModelConfig,
4
6
  RBLNPriorTransformerConfig,
5
7
  RBLNSD3Transformer2DModelConfig,
6
8
  RBLNUNet2DConditionModelConfig,
7
9
  RBLNVQModelConfig,
8
10
  )
9
11
  from .pipelines import (
12
+ RBLNCosmosTextToWorldPipelineConfig,
13
+ RBLNCosmosVideoToWorldPipelineConfig,
10
14
  RBLNKandinskyV22CombinedPipelineConfig,
11
15
  RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
12
16
  RBLNKandinskyV22Img2ImgPipelineConfig,
@@ -1,5 +1,7 @@
1
1
  from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
2
+ from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
2
3
  from .configuration_controlnet import RBLNControlNetModelConfig
4
+ from .configuration_cosmos_transformer import RBLNCosmosTransformer3DModelConfig
3
5
  from .configuration_prior_transformer import RBLNPriorTransformerConfig
4
6
  from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
5
7
  from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
@@ -0,0 +1,82 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+ from ....utils.logging import get_logger
19
+
20
+
21
+ logger = get_logger(__name__)
22
+
23
+
24
+ class RBLNAutoencoderKLCosmosConfig(RBLNModelConfig):
25
+ def __init__(
26
+ self,
27
+ batch_size: Optional[int] = None,
28
+ uses_encoder: Optional[bool] = None,
29
+ num_frames: Optional[int] = None,
30
+ height: Optional[int] = None,
31
+ width: Optional[int] = None,
32
+ num_channels_latents: Optional[int] = None,
33
+ vae_scale_factor_temporal: Optional[int] = None,
34
+ vae_scale_factor_spatial: Optional[int] = None,
35
+ use_slicing: Optional[bool] = None,
36
+ **kwargs,
37
+ ):
38
+ """
39
+ Args:
40
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
41
+ uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
42
+ When False, only the decoder is used (for latent-to-video conversion).
43
+ num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
44
+ height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
45
+ width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
46
+ num_channels_latents (Optional[int]): The number of channels in latent space.
47
+ vae_scale_factor_temporal (Optional[int]): The scaling factor between time space and latent space.
48
+ Determines how much shorter the latent representations are compared to the original videos.
49
+ vae_scale_factor_spatial (Optional[int]): The scaling factor between pixel space and latent space.
50
+ Determines how much smaller the latent representations are compared to the original videos.
51
+ use_slicing (Optional[bool]): Enable sliced VAE encoding and decoding.
52
+ If True, the VAE will split the input tensor in slices to compute encoding or decoding in several steps.
53
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
54
+
55
+ Raises:
56
+ ValueError: If batch_size is not a positive integer.
57
+ """
58
+ super().__init__(**kwargs)
59
+ # Since the Cosmos VAE Decoder already requires approximately 7.9 GiB of memory,
60
+ # Optimum-rbln cannot execute this model on RBLN-CA12 when the batch size > 1.
61
+ # However, the Cosmos VAE Decoder provides batch slicing when the batch size is greater than 1,
62
+ # Optimum-rbln utilizes this method by compiling with batch_size=1 to enable batch slicing.
63
+ self.batch_size = batch_size or 1
64
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
65
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
66
+ elif self.batch_size > 1:
67
+ logger.warning("The batch size of Cosmos VAE Decoder will be explicitly 1 for memory efficiency.")
68
+ self.batch_size = 1
69
+
70
+ self.uses_encoder = uses_encoder
71
+ self.num_frames = num_frames or 121
72
+ self.height = height or 704
73
+ self.width = width or 1280
74
+
75
+ self.num_channels_latents = num_channels_latents
76
+ self.vae_scale_factor_temporal = vae_scale_factor_temporal
77
+ self.vae_scale_factor_spatial = vae_scale_factor_spatial
78
+ self.use_slicing = use_slicing or False
79
+
80
+ @property
81
+ def image_size(self):
82
+ return (self.height, self.width)
@@ -0,0 +1,68 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+
19
+
20
+ class RBLNCosmosTransformer3DModelConfig(RBLNModelConfig):
21
+ def __init__(
22
+ self,
23
+ batch_size: Optional[int] = None,
24
+ num_frames: Optional[int] = None,
25
+ height: Optional[int] = None,
26
+ width: Optional[int] = None,
27
+ fps: Optional[int] = None,
28
+ max_seq_len: Optional[int] = None,
29
+ embedding_dim: Optional[int] = None,
30
+ num_channels_latents: Optional[int] = None,
31
+ num_latent_frames: Optional[int] = None,
32
+ latent_height: Optional[int] = None,
33
+ latent_width: Optional[int] = None,
34
+ **kwargs,
35
+ ):
36
+ """
37
+ Args:
38
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
39
+ num_frames (Optional[int]): The number of frames in the generated video. Defaults to 121.
40
+ height (Optional[int]): The height in pixels of the generated video. Defaults to 704.
41
+ width (Optional[int]): The width in pixels of the generated video. Defaults to 1280.
42
+ fps (Optional[int]): The frames per second of the generated video. Defaults to 30.
43
+ max_seq_len (Optional[int]): Maximum sequence length of prompt embeds.
44
+ embedding_dim (Optional[int]): Embedding vector dimension of prompt embeds.
45
+ num_channels_latents (Optional[int]): The number of channels in latent space.
46
+ latent_height (Optional[int]): The height in pixels in latent space.
47
+ latent_width (Optional[int]): The width in pixels in latent space.
48
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
49
+
50
+ Raises:
51
+ ValueError: If batch_size is not a positive integer.
52
+ """
53
+ super().__init__(**kwargs)
54
+ self.batch_size = batch_size or 1
55
+ self.num_frames = num_frames or 121
56
+ self.height = height or 704
57
+ self.width = width or 1280
58
+ self.fps = fps or 30
59
+
60
+ self.max_seq_len = max_seq_len
61
+ self.num_channels_latents = num_channels_latents
62
+ self.num_latent_frames = num_latent_frames
63
+ self.latent_height = latent_height
64
+ self.latent_width = latent_width
65
+ self.embedding_dim = embedding_dim
66
+
67
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
68
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
@@ -4,6 +4,7 @@ from .configuration_controlnet import (
4
4
  RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig,
5
5
  RBLNStableDiffusionXLControlNetPipelineConfig,
6
6
  )
7
+ from .configuration_cosmos import RBLNCosmosTextToWorldPipelineConfig, RBLNCosmosVideoToWorldPipelineConfig
7
8
  from .configuration_kandinsky2_2 import (
8
9
  RBLNKandinskyV22CombinedPipelineConfig,
9
10
  RBLNKandinskyV22Img2ImgCombinedPipelineConfig,
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNControlNetModelConfig, RBLNUNe
20
20
 
21
21
 
22
22
  class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
23
- """
24
- Base configuration for Stable Diffusion ControlNet pipelines.
25
- """
26
-
27
23
  submodules = ["text_encoder", "unet", "vae", "controlnet"]
28
24
  _vae_uses_encoder = False
29
25
 
@@ -0,0 +1,110 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from ....configuration_utils import RBLNModelConfig
18
+ from ....transformers import RBLNT5EncoderModelConfig
19
+ from ....utils.logging import get_logger
20
+ from ...pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
21
+ from ..models import RBLNAutoencoderKLCosmosConfig, RBLNCosmosTransformer3DModelConfig
22
+
23
+
24
+ logger = get_logger(__name__)
25
+
26
+
27
+ class _RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
28
+ submodules = ["text_encoder", "transformer", "vae", "safety_checker"]
29
+ _vae_uses_encoder = False
30
+
31
+ def __init__(
32
+ self,
33
+ text_encoder: Optional[RBLNT5EncoderModelConfig] = None,
34
+ transformer: Optional[RBLNCosmosTransformer3DModelConfig] = None,
35
+ vae: Optional[RBLNAutoencoderKLCosmosConfig] = None,
36
+ safety_checker: Optional[RBLNCosmosSafetyCheckerConfig] = None,
37
+ *,
38
+ batch_size: Optional[int] = None,
39
+ height: Optional[int] = None,
40
+ width: Optional[int] = None,
41
+ num_frames: Optional[int] = None,
42
+ fps: Optional[int] = None,
43
+ max_seq_len: Optional[int] = None,
44
+ **kwargs,
45
+ ):
46
+ """
47
+ Args:
48
+ text_encoder (Optional[RBLNT5EncoderModelConfig]): Configuration for the text encoder component.
49
+ Initialized as RBLNT5EncoderModelConfig if not provided.
50
+ transformer (Optional[RBLNCosmosTransformer3DModelConfig]): Configuration for the transformer model component.
51
+ Initialized as RBLNCosmosTransformer3DModelConfig if not provided.
52
+ vae (Optional[RBLNAutoencoderKLCosmosConfig]): Configuration for the VAE model component.
53
+ Initialized as RBLNAutoencoderKLCosmosConfig if not provided.
54
+ safety_checker (Optional[RBLNCosmosSafetyCheckerConfig]): Configuration for the safety checker component.
55
+ Initialized as RBLNCosmosSafetyCheckerConfig if not provided.
56
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
57
+ height (Optional[int]): Height of the generated videos.
58
+ width (Optional[int]): Width of the generated videos.
59
+ num_frames (Optional[int]): The number of frames in the generated video.
60
+ fps (Optional[int]): The frames per second of the generated video.
61
+ max_seq_len (Optional[int]): Maximum sequence length supported by the model.
62
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
63
+ """
64
+ super().__init__(**kwargs)
65
+
66
+ self.text_encoder = self.init_submodule_config(
67
+ RBLNT5EncoderModelConfig, text_encoder, batch_size=batch_size, max_seq_len=max_seq_len
68
+ )
69
+ self.transformer = self.init_submodule_config(
70
+ RBLNCosmosTransformer3DModelConfig,
71
+ transformer,
72
+ batch_size=batch_size,
73
+ max_seq_len=max_seq_len,
74
+ height=height,
75
+ width=width,
76
+ num_frames=num_frames,
77
+ fps=fps,
78
+ )
79
+ self.vae = self.init_submodule_config(
80
+ RBLNAutoencoderKLCosmosConfig,
81
+ vae,
82
+ batch_size=batch_size,
83
+ uses_encoder=self.__class__._vae_uses_encoder,
84
+ height=height,
85
+ width=width,
86
+ num_frames=num_frames,
87
+ )
88
+ self.safety_checker = self.init_submodule_config(
89
+ RBLNCosmosSafetyCheckerConfig,
90
+ safety_checker,
91
+ batch_size=batch_size,
92
+ height=height,
93
+ width=width,
94
+ )
95
+
96
+ @property
97
+ def batch_size(self):
98
+ return self.vae.batch_size
99
+
100
+ @property
101
+ def max_seq_len(self):
102
+ return self.text_encoder.max_seq_len
103
+
104
+
105
+ class RBLNCosmosTextToWorldPipelineConfig(_RBLNCosmosPipelineBaseConfig):
106
+ _vae_uses_encoder = False
107
+
108
+
109
+ class RBLNCosmosVideoToWorldPipelineConfig(_RBLNCosmosPipelineBaseConfig):
110
+ _vae_uses_encoder = True
@@ -21,8 +21,6 @@ from ..models.configuration_prior_transformer import RBLNPriorTransformerConfig
21
21
 
22
22
 
23
23
  class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
24
- """Base configuration class for Kandinsky V2.2 decoder pipelines."""
25
-
26
24
  submodules = ["unet", "movq"]
27
25
  _movq_uses_encoder = False
28
26
 
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
20
20
 
21
21
 
22
22
  class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
23
- """
24
- Base configuration for Stable Diffusion pipelines.
25
- """
26
-
27
23
  submodules = ["text_encoder", "unet", "vae"]
28
24
  _vae_uses_encoder = False
29
25
 
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNSD3Transformer2DModelConfig
20
20
 
21
21
 
22
22
  class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
23
- """
24
- Base configuration for Stable Diffusion 3 pipelines.
25
- """
26
-
27
23
  submodules = ["transformer", "text_encoder", "text_encoder_2", "text_encoder_3", "vae"]
28
24
  _vae_uses_encoder = False
29
25
 
@@ -115,6 +111,7 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
115
111
  text_encoder_3,
116
112
  batch_size=batch_size,
117
113
  max_seq_len=max_seq_len,
114
+ model_input_names=["input_ids"],
118
115
  )
119
116
  self.transformer = self.init_submodule_config(
120
117
  RBLNSD3Transformer2DModelConfig,
@@ -20,10 +20,6 @@ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
20
20
 
21
21
 
22
22
  class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
23
- """
24
- Base configuration for Stable Diffusion XL pipelines.
25
- """
26
-
27
23
  submodules = ["text_encoder", "text_encoder_2", "unet", "vae"]
28
24
  _vae_uses_encoder = False
29
25
 
@@ -45,7 +45,7 @@ class RBLNDiffusionMixin:
45
45
  To use this mixin:
46
46
 
47
47
  1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
48
- 2. Define the required _submodules class variable listing the components to be compiled.
48
+ 2. Define the required _submodules and _optional_submodules class variables listing the components to be compiled.
49
49
 
50
50
  Example:
51
51
  ```python
@@ -55,6 +55,7 @@ class RBLNDiffusionMixin:
55
55
 
56
56
  Class Variables:
57
57
  _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
58
+ _optional_submodules: List of submodule names compiled without inheriting RBLNModel (typically ["safety_checker"])
58
59
 
59
60
  Methods:
60
61
  from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
@@ -67,6 +68,7 @@ class RBLNDiffusionMixin:
67
68
 
68
69
  _connected_classes = {}
69
70
  _submodules = []
71
+ _optional_submodules = []
70
72
  _prefix = {}
71
73
  _rbln_config_class = None
72
74
  _hf_class = None
@@ -184,31 +186,42 @@ class RBLNDiffusionMixin:
184
186
  if export:
185
187
  # keep submodules if user passed any of them.
186
188
  passed_submodules = {
187
- name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
189
+ name: kwargs.pop(name)
190
+ for name in cls._submodules + cls._optional_submodules
191
+ if isinstance(kwargs.get(name), RBLNModel)
188
192
  }
189
193
 
190
194
  else:
191
195
  # raise error if any of submodules are torch module.
192
196
  model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
193
- for submodule_name in cls._submodules:
194
- if isinstance(kwargs.get(submodule_name), torch.nn.Module):
195
- raise AssertionError(
196
- f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
197
+ for submodule_name in cls._submodules + cls._optional_submodules:
198
+ passed_submodule = kwargs.get(submodule_name, None)
199
+
200
+ if passed_submodule is None:
201
+ module_name, class_name = model_index_config[submodule_name]
202
+ if module_name != "optimum.rbln":
203
+ raise ValueError(
204
+ f"Invalid module_name '{module_name}' found in model_index.json for "
205
+ f"submodule '{submodule_name}'. "
206
+ "Expected 'optimum.rbln'. Please check the model_index.json configuration."
207
+ "If you want to compile, set `export=True`."
208
+ )
209
+
210
+ submodule_cls = get_rbln_model_cls(class_name)
211
+ submodule_config = getattr(rbln_config, submodule_name)
212
+ submodule = submodule_cls.from_pretrained(
213
+ model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
197
214
  )
198
215
 
199
- module_name, class_name = model_index_config[submodule_name]
200
- if module_name != "optimum.rbln":
201
- raise ValueError(
202
- f"Invalid module_name '{module_name}' found in model_index.json for "
203
- f"submodule '{submodule_name}'. "
204
- "Expected 'optimum.rbln'. Please check the model_index.json configuration."
205
- )
216
+ else:
217
+ if passed_submodule.__class__.__name__.startswith("RBLN"):
218
+ submodule = passed_submodule
219
+
220
+ elif isinstance(passed_submodule, torch.nn.Module):
221
+ raise AssertionError(
222
+ f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
223
+ )
206
224
 
207
- submodule_cls = get_rbln_model_cls(class_name)
208
- submodule_config = getattr(rbln_config, submodule_name)
209
- submodule = submodule_cls.from_pretrained(
210
- model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
211
- )
212
225
  kwargs[submodule_name] = submodule
213
226
 
214
227
  with ContextRblnConfig(
@@ -352,10 +365,16 @@ class RBLNDiffusionMixin:
352
365
  # Causing warning messages.
353
366
 
354
367
  update_dict = {}
355
- for submodule_name in cls._submodules:
368
+ for submodule_name in cls._submodules + cls._optional_submodules:
356
369
  # replace submodule
357
- setattr(model, submodule_name, submodules[submodule_name])
358
- update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
370
+ if submodule_name in submodules:
371
+ setattr(model, submodule_name, submodules[submodule_name])
372
+ update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
373
+ else:
374
+ # It assumes that the modules in _optional_submodules are compiled
375
+ # and already registered as an attribute of the model.
376
+ update_dict[submodule_name] = ("optimum.rbln", getattr(model, submodule_name).__class__.__name__)
377
+
359
378
  if cls._load_connected_pipes:
360
379
  for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
361
380
  prefix = cls._prefix.get(connected_pipe_name, "")
@@ -386,31 +405,29 @@ class RBLNDiffusionMixin:
386
405
  return model
387
406
 
388
407
  def get_compiled_image_size(self):
389
- if hasattr(self, "vae"):
408
+ if hasattr(self, "vae") and hasattr(self.vae, "image_size"):
390
409
  compiled_image_size = self.vae.image_size
391
410
  else:
392
411
  compiled_image_size = None
393
412
  return compiled_image_size
394
413
 
395
414
  def handle_additional_kwargs(self, **kwargs):
396
- """
397
- Function to handle additional compile-time parameters during inference.
398
-
399
- If the additional variable is determined by another module, this method should be overrided.
400
-
401
- Example:
402
- ```python
403
- if hasattr(self, "movq"):
404
- compiled_image_size = self.movq.image_size
405
- kwargs["height"] = compiled_image_size[0]
406
- kwargs["width"] = compiled_image_size[1]
407
-
408
- compiled_num_frames = self.unet.rbln_config.num_frames
409
- if compiled_num_frames is not None:
410
- kwargs["num_frames"] = compiled_num_frames
411
- return kwargs
412
- ```
413
- """
415
+ # Function to handle additional compile-time parameters during inference.
416
+
417
+ # If the additional variable is determined by another module, this method should be overridden.
418
+
419
+ # Example:
420
+ # ```python
421
+ # if hasattr(self, "movq"):
422
+ # compiled_image_size = self.movq.image_size
423
+ # kwargs["height"] = compiled_image_size[0]
424
+ # kwargs["width"] = compiled_image_size[1]
425
+
426
+ # compiled_num_frames = self.unet.rbln_config.num_frames
427
+ # if compiled_num_frames is not None:
428
+ # kwargs["num_frames"] = compiled_num_frames
429
+ # return kwargs
430
+ # ```
414
431
  return kwargs
415
432
 
416
433
  @remove_compile_time_kwargs
@@ -20,6 +20,7 @@ from transformers.utils import _LazyModule
20
20
  _import_structure = {
21
21
  "autoencoders": [
22
22
  "RBLNAutoencoderKL",
23
+ "RBLNAutoencoderKLCosmos",
23
24
  "RBLNVQModel",
24
25
  ],
25
26
  "unets": [
@@ -28,6 +29,7 @@ _import_structure = {
28
29
  "controlnet": ["RBLNControlNetModel"],
29
30
  "transformers": [
30
31
  "RBLNPriorTransformer",
32
+ "RBLNCosmosTransformer3DModel",
31
33
  "RBLNSD3Transformer2DModel",
32
34
  ],
33
35
  }
@@ -35,10 +37,12 @@ _import_structure = {
35
37
  if TYPE_CHECKING:
36
38
  from .autoencoders import (
37
39
  RBLNAutoencoderKL,
40
+ RBLNAutoencoderKLCosmos,
38
41
  RBLNVQModel,
39
42
  )
40
43
  from .controlnet import RBLNControlNetModel
41
44
  from .transformers import (
45
+ RBLNCosmosTransformer3DModel,
42
46
  RBLNPriorTransformer,
43
47
  RBLNSD3Transformer2DModel,
44
48
  )