optimum-rbln 0.8.1a5__py3-none-any.whl → 0.8.1a7__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (70)
  1. optimum/rbln/__init__.py +18 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +21 -1
  4. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  5. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
  8. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
  15. optimum/rbln/diffusers/modeling_diffusers.py +57 -40
  16. optimum/rbln/diffusers/models/__init__.py +4 -0
  17. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  18. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
  19. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  20. optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
  21. optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
  22. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  25. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  30. optimum/rbln/modeling.py +38 -2
  31. optimum/rbln/modeling_base.py +18 -2
  32. optimum/rbln/transformers/modeling_generic.py +3 -3
  33. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  34. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  35. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  36. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  37. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
  38. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
  39. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
  40. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
  41. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
  42. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
  43. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
  44. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  45. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  46. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  47. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
  48. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
  50. optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
  51. optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
  52. optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
  53. optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
  54. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
  55. optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
  56. optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
  57. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  58. optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
  59. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
  60. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
  61. optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
  62. optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
  63. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
  64. optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
  65. optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
  66. optimum/rbln/utils/runtime_utils.py +49 -1
  67. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
  68. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
  69. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
  70. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/autoencoders/__init__.py
@@ -13,4 +13,5 @@
  # limitations under the License.

  from .autoencoder_kl import RBLNAutoencoderKL
+ from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
  from .vq_model import RBLNVQModel

optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py
@@ -80,7 +80,12 @@ class RBLNAutoencoderKL(RBLNModel):

              wrapped_model.eval()

-             compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
+             compiled_models[model_name] = cls.compile(
+                 wrapped_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[i],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map[model_name],
+             )

          return compiled_models

optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -0,0 +1,219 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Dict, List, Union
+
+ import rebel
+ import torch
+ from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos, CosmosCausalConv3d
+ from diffusers.models.autoencoders.vae import DecoderOutput
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
+ from torch.nn import functional as F
+ from transformers import PretrainedConfig
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...configurations import RBLNAutoencoderKLCosmosConfig
+ from .vae import RBLNRuntimeCosmosVAEDecoder, RBLNRuntimeCosmosVAEEncoder, _VAECosmosDecoder, _VAECosmosEncoder
+
+
+ if TYPE_CHECKING:
+     import torch
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+     from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+ logger = get_logger(__name__)
+
+
+ class RBLNAutoencoderKLCosmos(RBLNModel):
+     """
+     RBLN implementation of AutoencoderKLCosmos for diffusion models.
+
+     This model is used to accelerate AutoencoderKLCosmos models from diffusers library on RBLN NPUs.
+     It can be configured to include both encoder and decoder, or just the decoder part for latent-to-video
+     conversion.
+
+     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+     the library implements for all its models.
+     """
+
+     auto_model_class = AutoencoderKLCosmos
+     hf_library_name = "diffusers"
+     _rbln_config_class = RBLNAutoencoderKLCosmosConfig
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+
+         if self.rbln_config.uses_encoder:
+             self.encoder = RBLNRuntimeCosmosVAEEncoder(
+                 runtime=self.model[0], main_input_name="x", use_slicing=self.rbln_config.use_slicing
+             )
+
+         self.decoder = RBLNRuntimeCosmosVAEDecoder(
+             runtime=self.model[-1], main_input_name="z", use_slicing=self.rbln_config.use_slicing
+         )
+         self.image_size = self.rbln_config.image_size
+
+     @classmethod
+     def wrap_model_if_needed(
+         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
+     ) -> torch.nn.Module:
+         decoder_model = _VAECosmosDecoder(model)
+         decoder_model.eval()
+
+         if rbln_config.uses_encoder:
+             encoder_model = _VAECosmosEncoder(model)
+             encoder_model.eval()
+             return encoder_model, decoder_model
+         else:
+             return decoder_model
+
+     @classmethod
+     def get_compiled_model(
+         cls, model, rbln_config: RBLNAutoencoderKLCosmosConfig
+     ) -> Dict[str, rebel.RBLNCompiledModel]:
+         def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+             if self.temporal_pad != 0:
+                 hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
+                 hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
+             hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
+             return super(CosmosCausalConv3d, self).forward(hidden_states)
+
+         try:
+             original_forward = CosmosCausalConv3d.forward
+             CosmosCausalConv3d.forward = replaced_forward
+
+             compiled_models = {}
+             if rbln_config.uses_encoder:
+                 encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+                 enc_compiled_model = cls.compile(
+                     encoder_model,
+                     rbln_compile_config=rbln_config.compile_cfgs[0],
+                     create_runtimes=rbln_config.create_runtimes,
+                     device=rbln_config.device_map["encoder"],
+                 )
+                 compiled_models["encoder"] = enc_compiled_model
+             else:
+                 decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+             dec_compiled_model = cls.compile(
+                 decoder_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[-1],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map["decoder"],
+             )
+             compiled_models["decoder"] = dec_compiled_model
+
+         finally:
+             CosmosCausalConv3d.forward = original_forward
+
+         return compiled_models
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
+         rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
+         rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNAutoencoderKLCosmosConfig,
+     ) -> RBLNAutoencoderKLCosmosConfig:
+         batch_size = 1 if rbln_config.use_slicing else rbln_config.batch_size
+         compile_cfgs = []
+         if rbln_config.uses_encoder:
+             vae_enc_input_info = [
+                 (
+                     "x",
+                     [
+                         batch_size,
+                         model_config.in_channels,
+                         rbln_config.num_frames,
+                         rbln_config.height,
+                         rbln_config.width,
+                     ],
+                     "float32",
+                 ),
+             ]
+             compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+         num_latent_frames = (rbln_config.num_frames - 1) // rbln_config.vae_scale_factor_temporal + 1
+         latent_height = rbln_config.height // rbln_config.vae_scale_factor_spatial
+         latent_width = rbln_config.width // rbln_config.vae_scale_factor_spatial
+
+         vae_dec_input_info = [
+             (
+                 "z",
+                 [
+                     batch_size,
+                     rbln_config.num_channels_latents,
+                     num_latent_frames,
+                     latent_height,
+                     latent_width,
+                 ],
+                 "float32",
+             ),
+         ]
+         compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+         rbln_config.set_compile_cfgs(compile_cfgs)
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNAutoencoderKLCosmosConfig,
+     ) -> List[rebel.Runtime]:
+         if len(compiled_models) == 1:
+             # decoder
+             expected_models = ["decoder"]
+         else:
+             expected_models = ["encoder", "decoder"]
+
+         if any(model_name not in rbln_config.device_map for model_name in expected_models):
+             cls._raise_missing_compiled_file_error(expected_models)
+
+         device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+         return [
+             rebel.Runtime(
+                 compiled_model,
+                 tensor_type="pt",
+                 device=device_val,
+                 activate_profiler=rbln_config.activate_profiler,
+             )
+             for compiled_model, device_val in zip(compiled_models, device_vals)
+         ]
+
+     def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+         posterior = self.encoder.encode(x)
+         if not return_dict:
+             return (posterior,)
+         return AutoencoderKLOutput(latent_dist=posterior)
+
+     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+         decoded = self.decoder.decode(z)
+
+         if not return_dict:
+             return (decoded,)
+
+         return DecoderOutput(sample=decoded)
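For orientation, here is a minimal usage sketch of the new VAE wrapper, assuming an already loaded RBLNAutoencoderKLCosmos that was compiled with uses_encoder=True. The tensor shapes are hypothetical and must match the static shapes fixed at compile time (they mirror the input_info built in _update_rbln_config above); the helper name roundtrip is illustrative, not part of the package.

import torch
from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.models.modeling_outputs import AutoencoderKLOutput


def roundtrip(vae) -> torch.Tensor:
    # Encode a dummy clip and decode it back with an RBLNAutoencoderKLCosmos instance.
    # Assumed compile-time shapes: 121 frames at 704x1280, 3 input channels,
    # 16 latent channels, 8x temporal and 8x spatial compression.
    x = torch.randn(1, 3, 121, 704, 1280)         # [B, C, T, H, W]
    enc_out: AutoencoderKLOutput = vae.encode(x)  # wraps an IdentityDistribution
    z = enc_out.latent_dist.sample()              # [1, 16, 16, 88, 160] under the assumptions above
    dec_out: DecoderOutput = vae.decode(z)
    return dec_out.sample                         # back to [1, 3, 121, 704, 1280]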
optimum/rbln/diffusers/models/autoencoders/vae.py
@@ -12,15 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import List
+ from typing import TYPE_CHECKING, List

  import torch
- from diffusers import AutoencoderKL, VQModel
- from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution

  from ....utils.runtime_utils import RBLNPytorchRuntime


+ if TYPE_CHECKING:
+     from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
+
+
  class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
      def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
          moments = self.forward(x.contiguous())
@@ -33,6 +36,27 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
          return self.forward(z)


+ class RBLNRuntimeCosmosVAEEncoder(RBLNPytorchRuntime):
+     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         if self.use_slicing and x.shape[0] > 1:
+             encoded_slices = [self.forward(x_slice) for x_slice in x.split(1)]
+             h = torch.cat(encoded_slices)
+         else:
+             h = self.forward(x)
+         posterior = IdentityDistribution(h)
+         return posterior
+
+
+ class RBLNRuntimeCosmosVAEDecoder(RBLNPytorchRuntime):
+     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         if self.use_slicing and z.shape[0] > 1:
+             decoded_slices = [self.forward(z_slice) for z_slice in z.split(1)]
+             decoded = torch.cat(decoded_slices)
+         else:
+             decoded = self.forward(z)
+         return decoded
+
+
  class _VAEDecoder(torch.nn.Module):
      def __init__(self, vae: "AutoencoderKL"):
          super().__init__()
@@ -66,6 +90,26 @@ class _VAEEncoder(torch.nn.Module):
          return vae_out


+ class _VAECosmosEncoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLCosmos"):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, x):
+         vae_out = self.vae._encode(x)
+         return vae_out
+
+
+ class _VAECosmosDecoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLCosmos"):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, z):
+         vae_out = self.vae._decode(z, return_dict=False)
+         return vae_out
+
+
  class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
      def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
          h = self.forward(x.contiguous())
@@ -84,7 +128,7 @@ class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):


  class _VQEncoder(torch.nn.Module):
-     def __init__(self, vq_model: VQModel):
+     def __init__(self, vq_model: "VQModel"):
          super().__init__()
          self.vq_model = vq_model

@@ -99,7 +143,7 @@ class _VQEncoder(torch.nn.Module):


  class _VQDecoder(torch.nn.Module):
-     def __init__(self, vq_model: VQModel):
+     def __init__(self, vq_model: "VQModel"):
          super().__init__()
          self.vq_model = vq_model

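The new Cosmos runtimes introduce batch slicing: when use_slicing is enabled the graph is compiled with batch size 1 (see _update_rbln_config in autoencoder_kl_cosmos.py), and larger host-side batches are run one sample at a time and concatenated. A minimal sketch of that pattern in plain PyTorch, with runtime_fn standing in for the compiled runtime's forward call:

import torch


def sliced_forward(runtime_fn, x: torch.Tensor, use_slicing: bool) -> torch.Tensor:
    # Run a batch-1-compiled graph sample by sample and concatenate the results,
    # so the host batch size stays flexible while the device shape stays static.
    if use_slicing and x.shape[0] > 1:
        return torch.cat([runtime_fn(x_slice) for x_slice in x.split(1)])
    return runtime_fn(x)


# Toy check with a stand-in "runtime": a batch of 4 clips goes through one at a time.
out = sliced_forward(lambda t: t * 2.0, torch.randn(4, 3, 8, 16, 16), use_slicing=True)
assert out.shape == (4, 3, 8, 16, 16)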
optimum/rbln/diffusers/models/autoencoders/vq_model.py
@@ -78,7 +78,12 @@ class RBLNVQModel(RBLNModel):

              wrapped_model.eval()

-             compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
+             compiled_models[model_name] = cls.compile(
+                 wrapped_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[i],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map[model_name],
+             )

          return compiled_models

optimum/rbln/diffusers/models/transformers/__init__.py
@@ -13,4 +13,5 @@
  # limitations under the License.

  from .prior_transformer import RBLNPriorTransformer
+ from .transformer_cosmos import RBLNCosmosTransformer3DModel
  from .transformer_sd3 import RBLNSD3Transformer2DModel
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
@@ -0,0 +1,321 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+ from typing import TYPE_CHECKING, List, Optional, Union
+
+ import rebel
+ import torch
+ from diffusers import CosmosTransformer3DModel
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.transformers.transformer_cosmos import (
+     CosmosEmbedding,
+     CosmosLearnablePositionalEmbed,
+     CosmosPatchEmbed,
+     CosmosRotaryPosEmbed,
+ )
+ from torchvision import transforms
+
+ from ....configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...configurations import RBLNCosmosTransformer3DModelConfig
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+     from ...modeling_diffusers import RBLNCosmosTransformer3DModelConfig, RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+
+ logger = get_logger(__name__)
+
+
+ class CosmosTransformer3DModelWrapper(torch.nn.Module):
+     def __init__(
+         self,
+         model: CosmosTransformer3DModel,
+         num_latent_frames: int = 16,
+         latent_height: int = 88,
+         latent_width: int = 160,
+     ) -> None:
+         super().__init__()
+         self.model = model
+         self.num_latent_frames = num_latent_frames
+         self.latent_height = latent_height
+         self.latent_width = latent_width
+         self.p_t, self.p_h, self.p_w = model.config.patch_size
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         embedded_timestep: torch.Tensor,
+         temb: torch.Tensor,
+         image_rotary_emb_0: torch.Tensor,
+         image_rotary_emb_1: torch.Tensor,
+         extra_pos_emb: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = False,
+     ):
+         image_rotary_emb = [image_rotary_emb_0, image_rotary_emb_1]
+         for block in self.model.transformer_blocks:
+             hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 embedded_timestep=embedded_timestep,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 extra_pos_emb=extra_pos_emb,
+                 attention_mask=attention_mask,
+             )
+         post_patch_num_frames = self.num_latent_frames // self.p_t
+         post_patch_height = self.latent_height // self.p_h
+         post_patch_width = self.latent_width // self.p_w
+         hidden_states = self.model.norm_out(hidden_states, embedded_timestep, temb)
+         hidden_states = self.model.proj_out(hidden_states)
+         hidden_states = hidden_states.unflatten(2, (self.p_h, self.p_w, self.p_t, -1))
+         hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
+         hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
+         hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+         return (hidden_states,)
+
+
+ class RBLNCosmosTransformer3DModel(RBLNModel):
+     """RBLN wrapper for the Cosmos Transformer model."""
+
+     hf_library_name = "diffusers"
+     auto_model_class = CosmosTransformer3DModel
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+
+         hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
+         patch_embed_in_channels = (
+             self.config.in_channels + 1 if self.config.concat_padding_mask else self.config.in_channels
+         )
+         self.rope = CosmosRotaryPosEmbed(
+             hidden_size=self.config.attention_head_dim,
+             max_size=self.config.max_size,
+             patch_size=self.config.patch_size,
+             rope_scale=self.config.rope_scale,
+         )
+         self.rope.load_state_dict(artifacts["rope"])
+         if artifacts["learnable_pos_embed"] is None:
+             self.learnable_pos_embed = None
+         else:
+             self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
+                 hidden_size=hidden_size,
+                 max_size=self.config.max_size,
+                 patch_size=self.config.patch_size,
+             )
+             self.learnable_pos_embed.load_state_dict(artifacts["learnable_pos_embed"])
+         self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, self.config.patch_size, bias=False)
+         self.patch_embed.load_state_dict(artifacts["patch_embed"])
+         self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
+         self.time_embed.load_state_dict(artifacts["time_embed"])
+
+     def compute_embedding(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         fps: Optional[int] = None,
+         condition_mask: Optional[torch.Tensor] = None,
+         padding_mask: Optional[torch.Tensor] = None,
+     ):
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+         # 1. Concatenate padding mask if needed & prepare attention mask
+         if condition_mask is not None:
+             hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
+
+         if self.config.concat_padding_mask:
+             padding_mask = transforms.functional.resize(
+                 padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+             )
+             hidden_states = torch.cat(
+                 [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+             )
+
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]
+
+         # 2. Generate positional embeddings
+         image_rotary_emb = self.rope(hidden_states, fps=fps)
+         extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
+
+         # 3. Patchify input
+         p_t, p_h, p_w = self.config.patch_size
+         hidden_states = self.patch_embed(hidden_states)
+         hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]
+
+         # 4. Timestep embeddings
+         temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+
+         return (
+             hidden_states,
+             temb,
+             embedded_timestep,
+             image_rotary_emb[0],
+             image_rotary_emb[1],
+             extra_pos_emb,
+             attention_mask,
+         )
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+         num_latent_frames = rbln_config.num_latent_frames
+         latent_height = rbln_config.latent_height
+         latent_width = rbln_config.latent_width
+         return CosmosTransformer3DModelWrapper(
+             model=model,
+             num_latent_frames=num_latent_frames,
+             latent_height=latent_height,
+             latent_width=latent_width,
+         ).eval()
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> RBLNCosmosTransformer3DModelConfig:
+         rbln_config.transformer.num_latent_frames = (
+             rbln_config.transformer.num_frames - 1
+         ) // pipe.vae_scale_factor_temporal + 1
+         rbln_config.transformer.latent_height = rbln_config.transformer.height // pipe.vae_scale_factor_spatial
+         rbln_config.transformer.latent_width = rbln_config.transformer.width // pipe.vae_scale_factor_spatial
+         rbln_config.transformer.max_seq_len = pipe.text_encoder.config.n_positions
+         rbln_config.transformer.embedding_dim = pipe.text_encoder.encoder.embed_tokens.embedding_dim
+
+         return rbln_config
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNModelConfig,
+     ):
+         save_dict = {}
+         save_dict["rope"] = model.rope.state_dict()
+         if model.learnable_pos_embed is not None:
+             save_dict["learnable_pos_embed"] = model.learnable_pos_embed.state_dict()
+         save_dict["patch_embed"] = model.patch_embed.state_dict()
+         save_dict["time_embed"] = model.time_embed.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: "RBLNCosmosTransformer3DModelConfig",
+     ) -> RBLNCosmosTransformer3DModelConfig:
+         p_t, p_h, p_w = model_config.patch_size
+         hidden_dim = (
+             (rbln_config.num_latent_frames // p_t)
+             * (rbln_config.latent_height // p_h)
+             * (rbln_config.latent_width // p_w)
+         )
+         attention_head_dim = model_config.attention_head_dim
+         hidden_size = model.config.num_attention_heads * model.config.attention_head_dim
+         input_info = [
+             (
+                 "hidden_states",
+                 [
+                     rbln_config.batch_size,
+                     hidden_dim,
+                     hidden_size,
+                 ],
+                 "float32",
+             ),
+             (
+                 "encoder_hidden_states",
+                 [
+                     rbln_config.batch_size,
+                     rbln_config.max_seq_len,
+                     rbln_config.embedding_dim,
+                 ],
+                 "float32",
+             ),
+             ("embedded_timestep", [rbln_config.batch_size, hidden_size], "float32"),
+             ("temb", [1, hidden_size * 3], "float32"),
+             ("image_rotary_emb_0", [hidden_dim, attention_head_dim], "float32"),
+             ("image_rotary_emb_1", [hidden_dim, attention_head_dim], "float32"),
+             ("extra_pos_emb", [rbln_config.batch_size, hidden_dim, hidden_size], "float32"),
+         ]
+
+         compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([compile_config])
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNModelConfig,
+     ) -> List[rebel.Runtime]:
+         if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
+             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
+
+         return [
+             rebel.Runtime(
+                 compiled_model,
+                 tensor_type="pt",
+                 device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=120,
+             )
+             for compiled_model in compiled_models
+         ]
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         fps: Optional[int] = None,
+         condition_mask: Optional[torch.Tensor] = None,
+         padding_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ):
+         (
+             hidden_states,
+             temb,
+             embedded_timestep,
+             image_rotary_emb_0,
+             image_rotary_emb_1,
+             extra_pos_emb,
+             attention_mask,
+         ) = self.compute_embedding(hidden_states, timestep, attention_mask, fps, condition_mask, padding_mask)
+
+         hidden_states = self.model[0].forward(
+             hidden_states,
+             encoder_hidden_states,
+             embedded_timestep,
+             temb,
+             image_rotary_emb_0,
+             image_rotary_emb_1,
+             extra_pos_emb,
+         )
+
+         if not return_dict:
+             return (hidden_states,)
+         else:
+             return Transformer2DModelOutput(sample=hidden_states)
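The transformer is compiled against fully static shapes derived in _update_rbln_config above. The sketch below mirrors that shape math with hypothetical Cosmos-7B-like numbers; the patch size, head count, and text-encoder dimensions are illustrative assumptions, not values taken from this diff.

def compiled_transformer_shapes(
    batch_size: int = 1,
    patch_size: tuple = (1, 2, 2),
    num_latent_frames: int = 16,
    latent_height: int = 88,
    latent_width: int = 160,
    num_attention_heads: int = 32,
    attention_head_dim: int = 128,
    max_seq_len: int = 512,
    embedding_dim: int = 1024,
) -> dict:
    # Mirrors the input_info construction in _update_rbln_config (a sketch, not library API).
    p_t, p_h, p_w = patch_size
    hidden_dim = (num_latent_frames // p_t) * (latent_height // p_h) * (latent_width // p_w)
    hidden_size = num_attention_heads * attention_head_dim
    return {
        "hidden_states": (batch_size, hidden_dim, hidden_size),
        "encoder_hidden_states": (batch_size, max_seq_len, embedding_dim),
        "embedded_timestep": (batch_size, hidden_size),
        "temb": (1, hidden_size * 3),
        "image_rotary_emb_0": (hidden_dim, attention_head_dim),
        "image_rotary_emb_1": (hidden_dim, attention_head_dim),
        "extra_pos_emb": (batch_size, hidden_dim, hidden_size),
    }


# With the defaults above: hidden_dim = 16 * 44 * 80 = 56320 and hidden_size = 4096.
print(compiled_transformer_shapes())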