optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +164 -36
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +107 -78
- optimum/rbln/transformers/__init__.py +87 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +108 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
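
The central change in 0.7.4a6 is the configuration rework: the dict-based `RBLNConfig` from the deleted `optimum/rbln/modeling_config.py` is replaced by typed config classes built on `RBLNModelConfig` in the new `optimum/rbln/configuration_utils.py`, with one `configuration_*.py` module per architecture. The hunks below show the migration for three diffusers submodules. As a rough sketch of the intended usage — the attribute names come from the hunks below, while the checkpoint id and the exact `from_pretrained(..., export=True, rbln_config=...)` call shape are assumed from the existing optimum-rbln API, so treat it as illustrative:

```python
# Hedged sketch: typed per-model config instead of a plain rbln_kwargs dict.
# Attribute names are taken from the ControlNet hunks below; the checkpoint id
# and the export/rbln_config keyword usage are assumptions, not from this diff.
from optimum.rbln import RBLNControlNetModel, RBLNControlNetModelConfig

config = RBLNControlNetModelConfig(
    batch_size=2,                # doubled when classifier-free guidance is active
    max_seq_len=77,              # text_encoder.config.max_position_embeddings
    unet_sample_size=(64, 64),   # latent height/width
    vae_sample_size=(512, 512),  # input image height/width
)
model = RBLNControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",
    export=True,
    rbln_config=config,
)
```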
optimum/rbln/diffusers/models/controlnet.py:

@@ -13,20 +13,22 @@
 # limitations under the License.
 
 import importlib
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
+from diffusers.models.controlnet import ControlNetOutput
 from transformers import PretrainedConfig
 
+from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ...modeling import RBLNModel
-from ...modeling_config import RBLNCompileConfig, RBLNConfig
 from ...utils.logging import get_logger
-from ..
+from ..configurations import RBLNControlNetModelConfig
+from ..modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 
 logger = get_logger(__name__)
@@ -98,6 +100,7 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 class RBLNControlNetModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = ControlNetModel
+    output_class = ControlNetOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
@@ -106,7 +109,7 @@ class RBLNControlNetModel(RBLNModel):
         )
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -118,73 +121,50 @@ class RBLNControlNetModel(RBLNModel):
         return _ControlNetModel(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(
+    def update_rbln_config_using_pipe(
+        cls,
+        pipe: RBLNDiffusionMixin,
+        rbln_config: "RBLNDiffusionMixinConfig",
+        submodule_name: str,
+    ) -> "RBLNDiffusionMixinConfig":
         rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
         rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
-        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
-
-        batch_size = rbln_config.get("batch_size")
-        if not batch_size:
-            do_classifier_free_guidance = (
-                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
-            )
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
-                )
 
-        rbln_config.
-…
-…
-…
-…
-…
-                "batch_size": batch_size,
-            }
+        rbln_config.controlnet.max_seq_len = pipe.text_encoder.config.max_position_embeddings
+        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+        rbln_config.controlnet.text_model_hidden_size = text_model_hidden_size
+        rbln_config.controlnet.vae_sample_size = rbln_vae_cls.get_vae_sample_size(pipe, rbln_config.vae)
+        rbln_config.controlnet.unet_sample_size = rbln_unet_cls.get_unet_sample_size(
+            pipe, rbln_config.unet, image_size=rbln_config.image_size
         )
 
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-…
-    ) ->
-…
-…
-        unet_sample_size = rbln_kwargs.get("unet_sample_size")
-        vae_sample_size = rbln_kwargs.get("vae_sample_size")
-
-        if batch_size is None:
-            batch_size = 1
-
-        if unet_sample_size is None:
-            raise ValueError(
-                "`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
-            )
+        rbln_config: RBLNControlNetModelConfig,
+    ) -> RBLNModelConfig:
+        if rbln_config.unet_sample_size is None:
+            raise ValueError("`unet_sample_size` (latent height, width) must be specified (ex. unet's sample_size)")
 
-        if vae_sample_size is None:
-            raise ValueError(
-                "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
-            )
+        if rbln_config.vae_sample_size is None:
+            raise ValueError("`vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)")
 
-        if max_seq_len is None:
-            raise ValueError("`
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified")
 
         input_info = [
             (
                 "sample",
                 [
-                    batch_size,
+                    rbln_config.batch_size,
                     model_config.in_channels,
-                    unet_sample_size[0],
-                    unet_sample_size[1],
+                    rbln_config.unet_sample_size[0],
+                    rbln_config.unet_sample_size[1],
                 ],
                 "float32",
             ),
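
Note the new shape of `update_rbln_config_using_pipe`: instead of reading and mutating a plain kwargs dict (the batch-size/guidance-scale guessing is gone), each submodule writes into its own slot on a structured `RBLNDiffusionMixinConfig`. Roughly, the nesting looks like this — a sketch only; the field set is inferred from the assignments above, and the sibling sub-configs are assumptions:

```python
# Hedged sketch of the nested pipeline config implied by the hunk above:
# one sub-config per submodule, so ControlNet can fill in its own fields
# without clobbering unrelated settings. Values are illustrative.
rbln_config.controlnet.max_seq_len = 77              # from the text encoder
rbln_config.controlnet.vae_sample_size = (512, 512)  # via RBLN*.get_vae_sample_size
rbln_config.controlnet.unet_sample_size = (64, 64)   # via RBLN*.get_unet_sample_size
rbln_config.unet.batch_size = 2                      # sibling sub-config, untouched here
```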
@@ -196,7 +176,7 @@ class RBLNControlNetModel(RBLNModel):
             input_info.append(
                 (
                     "encoder_hidden_states",
-                    [batch_size, max_seq_len, model_config.cross_attention_dim],
+                    [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
                     "float32",
                 )
             )
@@ -204,25 +184,18 @@ class RBLNControlNetModel(RBLNModel):
         input_info.append(
             (
                 "controlnet_cond",
-                [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
+                [rbln_config.batch_size, 3, rbln_config.vae_sample_size[0], rbln_config.vae_sample_size[1]],
                 "float32",
             )
         )
         input_info.append(("conditioning_scale", [], "float32"))
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-…
-            input_info.append(("
-            input_info.append(("time_ids", [batch_size, 6], "float32"))
+            input_info.append(("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
+        rbln_config.set_compile_cfgs([rbln_compile_config])
         return rbln_config
 
     @property
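
For concreteness, here is roughly what the `input_info` assembled above resolves to for an SD1.5-style ControlNet compiled with `batch_size=2` and `max_seq_len=77`. The concrete dimensions (4 latent channels, 768 cross-attention dim, 512×512 images with 64×64 latents) are typical values, not taken from this diff:

```python
# Static shapes handed to RBLNCompileConfig (illustrative numbers; the
# text_embeds/time_ids entries are only added for SDXL-style configs where
# addition_embed_type == "text_time").
input_info = [
    ("sample", [2, 4, 64, 64], "float32"),               # latents
    ("encoder_hidden_states", [2, 77, 768], "float32"),  # text embeddings
    ("controlnet_cond", [2, 3, 512, 512], "float32"),    # conditioning image
    ("conditioning_scale", [], "float32"),               # scalar
]
```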
@@ -237,6 +210,7 @@ class RBLNControlNetModel(RBLNModel):
         controlnet_cond: torch.FloatTensor,
         conditioning_scale: torch.Tensor = 1.0,
         added_cond_kwargs: Dict[str, torch.Tensor] = {},
+        return_dict: bool = True,
         **kwargs,
     ):
         sample_batch_size = sample.size()[0]
@@ -246,14 +220,14 @@ class RBLNControlNetModel(RBLNModel):
         ):
             raise ValueError(
                 f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
-                "This may be caused by the '
-                "Adjust the batch size during compilation
+                "This may be caused by the 'guidance_scale' parameter, which doubles the runtime batch size of ControlNet in Stable Diffusion. "
+                "Adjust the batch size of ControlNet during compilation to match the runtime batch size.\n\n"
                 "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
             )
 
         added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
         if self.use_encoder_hidden_states:
-            output =
+            output = self.model[0](
                 sample.contiguous(),
                 timestep.float(),
                 encoder_hidden_states,
@@ -262,14 +236,25 @@ class RBLNControlNetModel(RBLNModel):
                 **added_cond_kwargs,
             )
         else:
-            output =
+            output = self.model[0](
                 sample.contiguous(),
                 timestep.float(),
                 controlnet_cond,
                 torch.tensor(conditioning_scale),
                 **added_cond_kwargs,
             )
+
         down_block_res_samples = output[:-1]
         mid_block_res_sample = output[-1]
+        output = (down_block_res_samples, mid_block_res_sample)
+        output = self._prepare_output(output, return_dict)
+        return output
 
-…
+    def _prepare_output(self, output, return_dict):
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            return ControlNetOutput(
+                down_block_res_samples=output[:-1],
+                mid_block_res_sample=output[-1],
+            )
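
With `output_class = ControlNetOutput` and the `_prepare_output` helper in place, `RBLNControlNetModel.forward` now follows the diffusers calling convention for both output styles. A usage sketch (the model handle and tensors are illustrative):

```python
# return_dict=True (the default) wraps the runtime results in diffusers'
# ControlNetOutput; return_dict=False returns the raw tuple instead.
out = controlnet(sample, timestep, encoder_hidden_states, controlnet_cond)
down_samples = out.down_block_res_samples
mid_sample = out.mid_block_res_sample

down_samples, mid_sample = controlnet(
    sample, timestep, encoder_hidden_states, controlnet_cond, return_dict=False
)
```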
optimum/rbln/diffusers/models/transformers/prior_transformer.py:

@@ -13,37 +13,22 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
-from transformers import PretrainedConfig, PreTrainedModel
 
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations.models import RBLNPriorTransformerConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
-…
-…
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
-…
-    def forward(
-        self, hidden_states, timestep, proj_embedding, encoder_hidden_states, attention_mask, return_dict: bool = True
-    ):
-        predicted_image_embedding = super().forward(
-            hidden_states,
-            timestep,
-            proj_embedding,
-            encoder_hidden_states,
-            attention_mask,
-        )
-        if return_dict:
-            return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
-        else:
-            return (predicted_image_embedding,)
+logger = get_logger(__name__)
 
 
 class _PriorTransformer(torch.nn.Module):
@@ -73,51 +58,28 @@ class _PriorTransformer(torch.nn.Module):
 class RBLNPriorTransformer(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = PriorTransformer
+    output_class = PriorTransformerOutput
+    output_key = "predicted_image_embedding"
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.runtime = RBLNRuntimePriorTransformer(runtime=self.model[0])
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         self.clip_mean = artifacts["clip_mean"]
         self.clip_std = artifacts["clip_std"]
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return _PriorTransformer(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(
-…
-…
-            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
-                )
-        embedding_dim = rbln_config.get("embedding_dim", pipe.prior.config.embedding_dim)
-        num_embeddings = rbln_config.get("num_embeddings", pipe.prior.config.num_embeddings)
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "embedding_dim": embedding_dim,
-                "num_embeddings": num_embeddings,
-            }
-        )
-
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
     def save_torch_artifacts(
-        cls,
-        model: "PreTrainedModel",
-        save_dir_path: Path,
-        subfolder: str,
-        rbln_config: RBLNConfig,
+        cls, model: "PreTrainedModel", save_dir_path: Path, subfolder: str, rbln_config: RBLNModelConfig
     ):
         save_dict = {}
         save_dict["clip_mean"] = model.clip_mean
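
`output_class` and `output_key` (here `PriorTransformerOutput` / `"predicted_image_embedding"`) feed a generic output-wrapping hook on the `RBLNModel` base class. The base implementation lives in `optimum/rbln/modeling.py` and is not part of this diff, so the following is only a guess at the contract, modeled on the `_prepare_output` override added for ControlNet above:

```python
# Hypothetical sketch of the base-class hook that output_class/output_key
# drive; the real modeling.py implementation is not shown in this diff.
def _prepare_output(self, output, return_dict: bool):
    if not return_dict:
        return (output,) if not isinstance(output, (tuple, list)) else output
    return self.output_class(**{self.output_key: output})
```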
@@ -125,50 +87,51 @@ class RBLNPriorTransformer(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors,
-…
-…
-…
-…
-        embedding_dim =
-        num_embeddings =
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNPriorTransformerConfig,
+    ) -> RBLNPriorTransformerConfig:
+        rbln_config.embedding_dim = rbln_config.embedding_dim or model_config.embedding_dim
+        rbln_config.num_embeddings = rbln_config.num_embeddings or model_config.num_embeddings
 
         input_info = [
-            ("hidden_states", [batch_size, embedding_dim], "float32"),
+            ("hidden_states", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
             ("timestep", [], "float32"),
-            ("proj_embedding", [batch_size, embedding_dim], "float32"),
-            (
-…
+            ("proj_embedding", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
+            (
+                "encoder_hidden_states",
+                [rbln_config.batch_size, rbln_config.num_embeddings, rbln_config.embedding_dim],
+                "float32",
+            ),
+            ("attention_mask", [rbln_config.batch_size, rbln_config.num_embeddings], "float32"),
         ]
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs([rbln_compile_config])
         return rbln_config
 
+    def post_process_latents(self, prior_latents):
+        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+        return prior_latents
+
     def forward(
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
         proj_embedding: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.
+        attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
-…
-…
+        # Convert timestep(long) and attention_mask(bool) to float
+        return super().forward(
+            hidden_states,
             timestep.float(),
             proj_embedding,
             encoder_hidden_states,
             attention_mask.float(),
-            return_dict,
+            return_dict=return_dict,
         )
-
-    def post_process_latents(self, prior_latents):
-        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
-        return prior_latents
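
`post_process_latents` un-whitens the prior's samples using the CLIP statistics stored in `torch_artifacts.pth` at export time (the Kandinsky 2.2 prior samples in CLIP-normalized space). Standalone, the arithmetic is just:

```python
import torch

# De-normalization performed by post_process_latents; shapes and values
# below are illustrative, clip_mean/clip_std normally come from the
# torch_artifacts.pth saved by save_torch_artifacts.
clip_mean = torch.zeros(1, 768)
clip_std = torch.ones(1, 768)
prior_latents = torch.randn(1, 768)
image_embeds = prior_latents * clip_std + clip_mean
```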
optimum/rbln/diffusers/models/transformers/transformer_sd3.py:

@@ -19,14 +19,16 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
 from transformers import PretrainedConfig
 
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...
+from ...configurations import RBLNSD3Transformer2DModelConfig
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 logger = get_logger(__name__)
 
@@ -58,84 +60,64 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
 
 class RBLNSD3Transformer2DModel(RBLNModel):
     hf_library_name = "diffusers"
+    auto_model_class = SD3Transformer2DModel
+    output_class = Transformer2DModelOutput
+    output_key = "sample"
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         return SD3Transformer2DModelWrapper(model).eval()
 
     @classmethod
-    def update_rbln_config_using_pipe(
-…
-…
-…
-…
-…
-…
-…
-        elif img_width and img_height:
-            sample_size = img_height // pipe.vae_scale_factor, img_width // pipe.vae_scale_factor
-
-        prompt_max_length = rbln_config.get("max_sequence_length", 256)
-        prompt_embed_length = pipe.tokenizer_max_length + prompt_max_length
-
-        batch_size = rbln_config.get("batch_size")
-        if not batch_size:
-            do_classifier_free_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        if rbln_config.sample_size is None:
+            if rbln_config.image_size is not None:
+                rbln_config.transformer.sample_size = (
+                    rbln_config.image_size[0] // pipe.vae_scale_factor,
+                    rbln_config.image_size[1] // pipe.vae_scale_factor,
                 )
+            else:
+                rbln_config.transformer.sample_size = pipe.default_sample_size
 
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "prompt_embed_length": prompt_embed_length,
-                "sample_size": sample_size,
-            }
-        )
-
+        prompt_embed_length = pipe.tokenizer_max_length + rbln_config.max_seq_len
+        rbln_config.transformer.prompt_embed_length = prompt_embed_length
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-…
-    ) ->
-…
+        rbln_config: RBLNSD3Transformer2DModelConfig,
+    ) -> RBLNSD3Transformer2DModelConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
 
-…
-…
-            sample_size = (sample_size, sample_size)
-
-        rbln_prompt_embed_length = rbln_kwargs.get("prompt_embed_length")
-        if rbln_prompt_embed_length is None:
-            raise ValueError("rbln_prompt_embed_length should be specified.")
+        if isinstance(rbln_config.sample_size, int):
+            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)
 
         input_info = [
             (
                 "hidden_states",
                 [
-…
+                    rbln_config.batch_size,
                     model_config.in_channels,
-                    sample_size[0],
-                    sample_size[1],
+                    rbln_config.sample_size[0],
+                    rbln_config.sample_size[1],
                 ],
                 "float32",
             ),
             (
                 "encoder_hidden_states",
                 [
-…
-…
+                    rbln_config.batch_size,
+                    rbln_config.prompt_embed_length,
                     model_config.joint_attention_dim,
                 ],
                 "float32",
@@ -143,24 +125,16 @@ class RBLNSD3Transformer2DModel(RBLNModel):
             (
                 "pooled_projections",
                 [
-…
+                    rbln_config.batch_size,
                     model_config.pooled_projection_dim,
                 ],
                 "float32",
             ),
-            ("timestep", [
+            ("timestep", [rbln_config.batch_size], "float32"),
         ]
 
-…
-…
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"batch_size": rbln_batch_size})
-
+        compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config.set_compile_cfgs([compile_config])
         return rbln_config
 
     @property
@@ -184,11 +158,12 @@ class RBLNSD3Transformer2DModel(RBLNModel):
             sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
         ):
             raise ValueError(
-                f"Mismatch between
+                f"Mismatch between transformer's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                 "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
-                "Adjust the batch size during compilation
+                "Adjust the batch size of transformer during compilation.\n\n"
                 "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
             )
 
-…
-…
+        return super().forward(
+            hidden_states, encoder_hidden_states, pooled_projections, timestep, return_dict=return_dict
+        )
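
The reworked error message above points at the usual cause of a batch mismatch: classifier-free guidance. The arithmetic the guard enforces, in brief (names illustrative):

```python
# With guidance_scale > 1 the pipeline stacks unconditional and conditional
# inputs, so the transformer runs at twice the user-facing batch size.
images_per_prompt = 1
guidance_scale = 7.0
runtime_batch = images_per_prompt * (2 if guidance_scale > 1.0 else 1)

compiled_batch = 2  # what the model was compiled for
assert runtime_batch == compiled_batch  # otherwise forward() raises ValueError
```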
|