optimum-rbln 0.8.1a5 → 0.8.1a7 (py3-none-any.whl)
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry; it is provided for informational purposes only.
- optimum/rbln/__init__.py +18 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
- optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
- optimum/rbln/diffusers/modeling_diffusers.py +57 -40
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/modeling.py +38 -2
- optimum/rbln/modeling_base.py +18 -2
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
- optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
- optimum/rbln/utils/runtime_utils.py +49 -1
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
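
The headline change in 0.8.1a7 is support for NVIDIA's Cosmos video diffusion models: new model and pipeline configurations, an RBLN port of `AutoencoderKLCosmos`, a `CosmosTransformer3DModel` wrapper, the Cosmos guardrail stack, and text2world/video2world pipelines. As rough orientation, here is a minimal usage sketch; the pipeline class name is inferred from `pipeline_cosmos_text2world.py` and optimum-rbln's `RBLN*` naming convention, and the checkpoint id is a placeholder, so treat both as assumptions rather than facts confirmed by this diff.

```python
# Hypothetical usage sketch: class name and checkpoint id are assumptions.
from optimum.rbln import RBLNCosmosTextToWorldPipeline  # assumed export

pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # placeholder checkpoint id
    export=True,  # compile the submodules for RBLN NPUs on first load
)
video = pipe(prompt="a robot arm stacking wooden blocks").frames[0]
```

The per-file hunks below show the core of the new model-level support.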
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py

```diff
@@ -80,7 +80,12 @@ class RBLNAutoencoderKL(RBLNModel):
 
             wrapped_model.eval()
 
-            compiled_models[model_name] = cls.compile(
+            compiled_models[model_name] = cls.compile(
+                wrapped_model,
+                rbln_compile_config=rbln_config.compile_cfgs[i],
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device_map[model_name],
+            )
 
         return compiled_models
 
```
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py (new file)

```diff
@@ -0,0 +1,219 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, List, Union
+
+import rebel
+import torch
+from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos, CosmosCausalConv3d
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from torch.nn import functional as F
+from transformers import PretrainedConfig
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ...configurations import RBLNAutoencoderKLCosmosConfig
+from .vae import RBLNRuntimeCosmosVAEDecoder, RBLNRuntimeCosmosVAEEncoder, _VAECosmosDecoder, _VAECosmosEncoder
+
+
+if TYPE_CHECKING:
+    import torch
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+logger = get_logger(__name__)
+
+
+class RBLNAutoencoderKLCosmos(RBLNModel):
+    """
+    RBLN implementation of AutoencoderKLCosmos for diffusion models.
+
+    This model is used to accelerate AutoencoderKLCosmos models from the diffusers library on RBLN NPUs.
+    It can be configured to include both encoder and decoder, or just the decoder part for latent-to-video
+    conversion.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """
+
+    auto_model_class = AutoencoderKLCosmos
+    hf_library_name = "diffusers"
+    _rbln_config_class = RBLNAutoencoderKLCosmosConfig
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+
+        if self.rbln_config.uses_encoder:
+            self.encoder = RBLNRuntimeCosmosVAEEncoder(
+                runtime=self.model[0], main_input_name="x", use_slicing=self.rbln_config.use_slicing
+            )
+
+        self.decoder = RBLNRuntimeCosmosVAEDecoder(
+            runtime=self.model[-1], main_input_name="z", use_slicing=self.rbln_config.use_slicing
+        )
+        self.image_size = self.rbln_config.image_size
+
+    @classmethod
+    def wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
+    ) -> torch.nn.Module:
+        decoder_model = _VAECosmosDecoder(model)
+        decoder_model.eval()
+
+        if rbln_config.uses_encoder:
+            encoder_model = _VAECosmosEncoder(model)
+            encoder_model.eval()
+            return encoder_model, decoder_model
+        else:
+            return decoder_model
+
+    @classmethod
+    def get_compiled_model(
+        cls, model, rbln_config: RBLNAutoencoderKLCosmosConfig
+    ) -> Dict[str, rebel.RBLNCompiledModel]:
+        def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            if self.temporal_pad != 0:
+                hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
+                hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
+            hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
+            return super(CosmosCausalConv3d, self).forward(hidden_states)
+
+        try:
+            original_forward = CosmosCausalConv3d.forward
+            CosmosCausalConv3d.forward = replaced_forward
+
+            compiled_models = {}
+            if rbln_config.uses_encoder:
+                encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+                enc_compiled_model = cls.compile(
+                    encoder_model,
+                    rbln_compile_config=rbln_config.compile_cfgs[0],
+                    create_runtimes=rbln_config.create_runtimes,
+                    device=rbln_config.device_map["encoder"],
+                )
+                compiled_models["encoder"] = enc_compiled_model
+            else:
+                decoder_model = cls.wrap_model_if_needed(model, rbln_config)
+            dec_compiled_model = cls.compile(
+                decoder_model,
+                rbln_compile_config=rbln_config.compile_cfgs[-1],
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device_map["decoder"],
+            )
+            compiled_models["decoder"] = dec_compiled_model
+
+        finally:
+            CosmosCausalConv3d.forward = original_forward
+
+        return compiled_models
+
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
+        rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
+        rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNAutoencoderKLCosmosConfig,
+    ) -> RBLNAutoencoderKLCosmosConfig:
+        batch_size = 1 if rbln_config.use_slicing else rbln_config.batch_size
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
+            vae_enc_input_info = [
+                (
+                    "x",
+                    [
+                        batch_size,
+                        model_config.in_channels,
+                        rbln_config.num_frames,
+                        rbln_config.height,
+                        rbln_config.width,
+                    ],
+                    "float32",
+                ),
+            ]
+            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+        num_latent_frames = (rbln_config.num_frames - 1) // rbln_config.vae_scale_factor_temporal + 1
+        latent_height = rbln_config.height // rbln_config.vae_scale_factor_spatial
+        latent_width = rbln_config.width // rbln_config.vae_scale_factor_spatial
+
+        vae_dec_input_info = [
+            (
+                "z",
+                [
+                    batch_size,
+                    rbln_config.num_channels_latents,
+                    num_latent_frames,
+                    latent_height,
+                    latent_width,
+                ],
+                "float32",
+            ),
+        ]
+        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+        rbln_config.set_compile_cfgs(compile_cfgs)
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNAutoencoderKLCosmosConfig,
+    ) -> List[rebel.Runtime]:
+        if len(compiled_models) == 1:
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            expected_models = ["encoder", "decoder"]
+
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
+
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+        return [
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+            )
+            for compiled_model, device_val in zip(compiled_models, device_vals)
+        ]
+
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+        posterior = self.encoder.encode(x)
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+        decoded = self.decoder.decode(z)
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
```
optimum/rbln/diffusers/models/autoencoders/vae.py

```diff
@@ -12,15 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import torch
-from diffusers import
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
 
 from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
+if TYPE_CHECKING:
+    from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel
+
+
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         moments = self.forward(x.contiguous())
@@ -33,6 +36,27 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
         return self.forward(z)
 
 
+class RBLNRuntimeCosmosVAEEncoder(RBLNPytorchRuntime):
+    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.forward(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.forward(x)
+        posterior = IdentityDistribution(h)
+        return posterior
+
+
+class RBLNRuntimeCosmosVAEDecoder(RBLNPytorchRuntime):
+    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self.forward(z_slice) for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self.forward(z)
+        return decoded
+
+
 class _VAEDecoder(torch.nn.Module):
     def __init__(self, vae: "AutoencoderKL"):
         super().__init__()
@@ -66,6 +90,26 @@ class _VAEEncoder(torch.nn.Module):
         return vae_out
 
 
+class _VAECosmosEncoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKLCosmos"):
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, x):
+        vae_out = self.vae._encode(x)
+        return vae_out
+
+
+class _VAECosmosDecoder(torch.nn.Module):
+    def __init__(self, vae: "AutoencoderKLCosmos"):
+        super().__init__()
+        self.vae = vae
+
+    def forward(self, z):
+        vae_out = self.vae._decode(z, return_dict=False)
+        return vae_out
+
+
 class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         h = self.forward(x.contiguous())
@@ -84,7 +128,7 @@ class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
 
 
 class _VQEncoder(torch.nn.Module):
-    def __init__(self, vq_model: VQModel):
+    def __init__(self, vq_model: "VQModel"):
         super().__init__()
         self.vq_model = vq_model
 
@@ -99,7 +143,7 @@ class _VQEncoder(torch.nn.Module):
 
 
 class _VQDecoder(torch.nn.Module):
-    def __init__(self, vq_model: VQModel):
+    def __init__(self, vq_model: "VQModel"):
         super().__init__()
         self.vq_model = vq_model
 
```
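
The new `RBLNRuntimeCosmosVAEEncoder` and `RBLNRuntimeCosmosVAEDecoder` above implement diffusers-style batch slicing: the RBLN graph is compiled for batch size 1 when `use_slicing` is enabled (see `_update_rbln_config` in `autoencoder_kl_cosmos.py`), so larger batches are split, run slice by slice, and concatenated back together. Note also that the Cosmos encoder wraps its output in `IdentityDistribution` rather than `DiagonalGaussianDistribution`, mirroring the upstream diffusers behavior for this VAE. A self-contained sketch of the slicing pattern, with `run` standing in for the compiled runtime call:

```python
# Batch-slicing sketch; `run` stands in for the compiled runtime's forward.
import torch

def run(x: torch.Tensor) -> torch.Tensor:
    return x * 0.5  # placeholder computation

def sliced_call(x: torch.Tensor, use_slicing: bool = True) -> torch.Tensor:
    if use_slicing and x.shape[0] > 1:
        # the compiled graph only accepts batch 1, so run each slice separately
        return torch.cat([run(piece) for piece in x.split(1)])
    return run(x)

out = sliced_call(torch.randn(4, 16, 3, 8, 8))  # four batch-1 invocations
```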
optimum/rbln/diffusers/models/autoencoders/vq_model.py

```diff
@@ -78,7 +78,12 @@ class RBLNVQModel(RBLNModel):
 
             wrapped_model.eval()
 
-            compiled_models[model_name] = cls.compile(
+            compiled_models[model_name] = cls.compile(
+                wrapped_model,
+                rbln_compile_config=rbln_config.compile_cfgs[i],
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device_map[model_name],
+            )
 
         return compiled_models
 
```
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py (new file)

```diff
@@ -0,0 +1,321 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import rebel
+import torch
+from diffusers import CosmosTransformer3DModel
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_cosmos import (
+    CosmosEmbedding,
+    CosmosLearnablePositionalEmbed,
+    CosmosPatchEmbed,
+    CosmosRotaryPosEmbed,
+)
+from torchvision import transforms
+
+from ....configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNModelConfig
+from ....modeling import RBLNModel
+from ....utils.logging import get_logger
+from ...configurations import RBLNCosmosTransformer3DModelConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNCosmosTransformer3DModelConfig, RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+
+logger = get_logger(__name__)
+
+
+class CosmosTransformer3DModelWrapper(torch.nn.Module):
+    def __init__(
+        self,
+        model: CosmosTransformer3DModel,
+        num_latent_frames: int = 16,
+        latent_height: int = 88,
+        latent_width: int = 160,
+    ) -> None:
+        super().__init__()
+        self.model = model
+        self.num_latent_frames = num_latent_frames
+        self.latent_height = latent_height
+        self.latent_width = latent_width
+        self.p_t, self.p_h, self.p_w = model.config.patch_size
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        embedded_timestep: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb_0: torch.Tensor,
+        image_rotary_emb_1: torch.Tensor,
+        extra_pos_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = False,
+    ):
+        image_rotary_emb = [image_rotary_emb_0, image_rotary_emb_1]
+        for block in self.model.transformer_blocks:
+            hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                embedded_timestep=embedded_timestep,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                extra_pos_emb=extra_pos_emb,
+                attention_mask=attention_mask,
+            )
+        post_patch_num_frames = self.num_latent_frames // self.p_t
+        post_patch_height = self.latent_height // self.p_h
+        post_patch_width = self.latent_width // self.p_w
+        hidden_states = self.model.norm_out(hidden_states, embedded_timestep, temb)
+        hidden_states = self.model.proj_out(hidden_states)
+        hidden_states = hidden_states.unflatten(2, (self.p_h, self.p_w, self.p_t, -1))
+        hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
+        hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
+        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        return (hidden_states,)
+
+
+class RBLNCosmosTransformer3DModel(RBLNModel):
+    """RBLN wrapper for the Cosmos Transformer model."""
+
+    hf_library_name = "diffusers"
+    auto_model_class = CosmosTransformer3DModel
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+
+        hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
+        patch_embed_in_channels = (
+            self.config.in_channels + 1 if self.config.concat_padding_mask else self.config.in_channels
+        )
+        self.rope = CosmosRotaryPosEmbed(
+            hidden_size=self.config.attention_head_dim,
+            max_size=self.config.max_size,
+            patch_size=self.config.patch_size,
+            rope_scale=self.config.rope_scale,
+        )
+        self.rope.load_state_dict(artifacts["rope"])
+        if artifacts["learnable_pos_embed"] is None:
+            self.learnable_pos_embed = None
+        else:
+            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
+                hidden_size=hidden_size,
+                max_size=self.config.max_size,
+                patch_size=self.config.patch_size,
+            )
+            self.learnable_pos_embed.load_state_dict(artifacts["learnable_pos_embed"])
+        self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, self.config.patch_size, bias=False)
+        self.patch_embed.load_state_dict(artifacts["patch_embed"])
+        self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
+        self.time_embed.load_state_dict(artifacts["time_embed"])
+
+    def compute_embedding(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        fps: Optional[int] = None,
+        condition_mask: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+    ):
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+        # 1. Concatenate padding mask if needed & prepare attention mask
+        if condition_mask is not None:
+            hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
+
+        if self.config.concat_padding_mask:
+            padding_mask = transforms.functional.resize(
+                padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+            )
+            hidden_states = torch.cat(
+                [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+            )
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]
+
+        # 2. Generate positional embeddings
+        image_rotary_emb = self.rope(hidden_states, fps=fps)
+        extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
+
+        # 3. Patchify input
+        p_t, p_h, p_w = self.config.patch_size
+        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]
+
+        # 4. Timestep embeddings
+        temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+
+        return (
+            hidden_states,
+            temb,
+            embedded_timestep,
+            image_rotary_emb[0],
+            image_rotary_emb[1],
+            extra_pos_emb,
+            attention_mask,
+        )
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+        num_latent_frames = rbln_config.num_latent_frames
+        latent_height = rbln_config.latent_height
+        latent_width = rbln_config.latent_width
+        return CosmosTransformer3DModelWrapper(
+            model=model,
+            num_latent_frames=num_latent_frames,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        ).eval()
+
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> RBLNCosmosTransformer3DModelConfig:
+        rbln_config.transformer.num_latent_frames = (
+            rbln_config.transformer.num_frames - 1
+        ) // pipe.vae_scale_factor_temporal + 1
+        rbln_config.transformer.latent_height = rbln_config.transformer.height // pipe.vae_scale_factor_spatial
+        rbln_config.transformer.latent_width = rbln_config.transformer.width // pipe.vae_scale_factor_spatial
+        rbln_config.transformer.max_seq_len = pipe.text_encoder.config.n_positions
+        rbln_config.transformer.embedding_dim = pipe.text_encoder.encoder.embed_tokens.embedding_dim
+
+        return rbln_config
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNModelConfig,
+    ):
+        save_dict = {}
+        save_dict["rope"] = model.rope.state_dict()
+        if model.learnable_pos_embed is not None:
+            save_dict["learnable_pos_embed"] = model.learnable_pos_embed.state_dict()
+        save_dict["patch_embed"] = model.patch_embed.state_dict()
+        save_dict["time_embed"] = model.time_embed.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: "RBLNCosmosTransformer3DModelConfig",
+    ) -> RBLNCosmosTransformer3DModelConfig:
+        p_t, p_h, p_w = model_config.patch_size
+        hidden_dim = (
+            (rbln_config.num_latent_frames // p_t)
+            * (rbln_config.latent_height // p_h)
+            * (rbln_config.latent_width // p_w)
+        )
+        attention_head_dim = model_config.attention_head_dim
+        hidden_size = model.config.num_attention_heads * model.config.attention_head_dim
+        input_info = [
+            (
+                "hidden_states",
+                [
+                    rbln_config.batch_size,
+                    hidden_dim,
+                    hidden_size,
+                ],
+                "float32",
+            ),
+            (
+                "encoder_hidden_states",
+                [
+                    rbln_config.batch_size,
+                    rbln_config.max_seq_len,
+                    rbln_config.embedding_dim,
+                ],
+                "float32",
+            ),
+            ("embedded_timestep", [rbln_config.batch_size, hidden_size], "float32"),
+            ("temb", [1, hidden_size * 3], "float32"),
+            ("image_rotary_emb_0", [hidden_dim, attention_head_dim], "float32"),
+            ("image_rotary_emb_1", [hidden_dim, attention_head_dim], "float32"),
+            ("extra_pos_emb", [rbln_config.batch_size, hidden_dim, hidden_size], "float32"),
+        ]
+
+        compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([compile_config])
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNModelConfig,
+    ) -> List[rebel.Runtime]:
+        if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
+            cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
+
+        return [
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
+                activate_profiler=rbln_config.activate_profiler,
+                timeout=120,
+            )
+            for compiled_model in compiled_models
+        ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        fps: Optional[int] = None,
+        condition_mask: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        (
+            hidden_states,
+            temb,
+            embedded_timestep,
+            image_rotary_emb_0,
+            image_rotary_emb_1,
+            extra_pos_emb,
+            attention_mask,
+        ) = self.compute_embedding(hidden_states, timestep, attention_mask, fps, condition_mask, padding_mask)
+
+        hidden_states = self.model[0].forward(
+            hidden_states,
+            encoder_hidden_states,
+            embedded_timestep,
+            temb,
+            image_rotary_emb_0,
+            image_rotary_emb_1,
+            extra_pos_emb,
+        )
+
+        if not return_dict:
+            return (hidden_states,)
+        else:
+            return Transformer2DModelOutput(sample=hidden_states)
```