optimum-rbln 0.7.4a3__py3-none-any.whl → 0.7.4a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +156 -36
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +85 -80
- optimum/rbln/transformers/__init__.py +79 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +96 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a3.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a3.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/unets/unet_2d_condition.py
CHANGED
@@ -16,17 +16,18 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
 from transformers import PretrainedConfig
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations import RBLNUNet2DConditionModelConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 logger = get_logger(__name__)
 
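Alongside this import shuffle, the file list above shows optimum/rbln/modeling_config.py deleted (+0 -310) and optimum/rbln/configuration_utils.py added (+772 -0). For downstream code the migration is the same rename, sketched here using only paths visible in this diff:

    # 0.7.4a3 (module removed in 0.7.4a5):
    #   from optimum.rbln.modeling_config import RBLNCompileConfig, RBLNConfig
    # 0.7.4a5: RBLNCompileConfig moves to configuration_utils, and the
    # dict-driven RBLNConfig gives way to typed per-model configs
    # (RBLNModelConfig and its subclasses).
    from optimum.rbln.configuration_utils import RBLNCompileConfig, RBLNModelConfig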
@@ -141,10 +142,13 @@ class _UNet_Kandinsky(torch.nn.Module):
 class RBLNUNet2DConditionModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = UNet2DConditionModel
+    _rbln_config_class = RBLNUNet2DConditionModelConfig
+    output_class = UNet2DConditionOutput
+    output_key = "sample"
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
-        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
+        self.in_features = self.rbln_config.in_features
         if self.in_features is not None:
 
             @dataclass
@@ -158,7 +162,9 @@ class RBLNUNet2DConditionModel(RBLNModel):
         self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
+    ) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         elif model.config.addition_embed_type == "image":
@@ -168,117 +174,117 @@ class RBLNUNet2DConditionModel(RBLNModel):
 
     @classmethod
     def get_unet_sample_size(
-        cls,
-        …
+        cls,
+        pipe: RBLNDiffusionMixin,
+        rbln_config: RBLNUNet2DConditionModelConfig,
+        image_size: Optional[Tuple[int, int]] = None,
+    ) -> Tuple[int, int]:
         scale_factor = pipe.movq_scale_factor if hasattr(pipe, "movq_scale_factor") else pipe.vae_scale_factor
-        …
-        if rbln_config["img2img_pipeline"]:
+
+        if image_size is None:
+            if "Img2Img" in pipe.__class__.__name__:
                 if hasattr(pipe, "vae"):
                     # In case of img2img, sample size of unet is determined by vae encoder.
                     vae_sample_size = pipe.vae.config.sample_size
                     if isinstance(vae_sample_size, int):
-                        …
+                        vae_sample_size = (vae_sample_size, vae_sample_size)
+
+                    sample_size = (
+                        vae_sample_size[0] // scale_factor,
+                        vae_sample_size[1] // scale_factor,
+                    )
                 elif hasattr(pipe, "movq"):
                     logger.warning(
-                        "RBLN config '…
+                        "RBLN config 'image_size' should have been provided for this pipeline. "
                         "Both variable will be set 512 by default."
                     )
                     sample_size = (512 // scale_factor, 512 // scale_factor)
             else:
                 sample_size = pipe.unet.config.sample_size
+                if isinstance(sample_size, int):
+                    sample_size = (sample_size, sample_size)
         else:
             sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
 
         return sample_size
 
     @classmethod
-    def update_rbln_config_using_pipe(…
-        …
-            do_classifier_free_guidance = (
-                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
-            )
-            batch_size = 2 if do_classifier_free_guidance else 1
-        else:
-            if rbln_config.get("guidance_scale"):
-                logger.warning(
-                    "guidance_scale is ignored because batch size is explicitly specified. "
-                    "To ensure consistent behavior, consider removing the guidance scale or "
-                    "adjusting the batch size configuration as needed."
-                )
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.unet.text_model_hidden_size = (
+            pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+        )
+        rbln_config.unet.image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None
 
-        max_seq_len = …
-        rbln_config.update(
-            {
-                "max_seq_len": max_seq_len,
-                "text_model_hidden_size": text_model_hidden_size,
-                "image_model_hidden_size": image_model_hidden_size,
-                "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
-                "batch_size": batch_size,
-                "is_controlnet": "controlnet" in pipe.config.keys(),
-            }
+        rbln_config.unet.max_seq_len = (
+            pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
         )
 
+        rbln_config.unet.sample_size = cls.get_unet_sample_size(
+            pipe, rbln_config.unet, image_size=rbln_config.image_size
+        )
+        rbln_config.unet.use_additional_residuals = "controlnet" in pipe.config.keys()
+
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        …
-        sample_size = rbln_kwargs.get("sample_size")
-        is_controlnet = rbln_kwargs.get("is_controlnet")
-        rbln_in_features = None
-
-        if batch_size is None:
-            batch_size = 1
-
-        if sample_size is None:
-            sample_size = model_config.sample_size
+        rbln_config: RBLNUNet2DConditionModelConfig,
+    ) -> RBLNUNet2DConditionModelConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
 
-        if isinstance(sample_size, int):
-            sample_size = (sample_size, sample_size)
+        if isinstance(rbln_config.sample_size, int):
+            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)
 
         input_info = [
-            ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
+            (
+                "sample",
+                [
+                    rbln_config.batch_size,
+                    model_config.in_channels,
+                    rbln_config.sample_size[0],
+                    rbln_config.sample_size[1],
+                ],
+                "float32",
+            ),
             ("timestep", [], "float32"),
         ]
 
-        if max_seq_len is not None:
+        if rbln_config.max_seq_len is not None:
             input_info.append(
-                ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32")
+                (
+                    "encoder_hidden_states",
+                    [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
+                    "float32",
+                ),
             )
 
-        if is_controlnet:
+        if rbln_config.use_additional_residuals:
             # down block addtional residuals
-            first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
-            height, width = sample_size[0], sample_size[1]
+            first_shape = [
+                rbln_config.batch_size,
+                model_config.block_out_channels[0],
+                rbln_config.sample_size[0],
+                rbln_config.sample_size[1],
+            ]
+            height, width = rbln_config.sample_size[0], rbln_config.sample_size[1]
             input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
             name_idx = 1
             for idx, _ in enumerate(model_config.down_block_types):
-                shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
                 for _ in range(model_config.layers_per_block):
                     input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                     name_idx += 1
                 if idx != len(model_config.down_block_types) - 1:
                     height = height // 2
                     width = width // 2
-                    shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                    shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
                     input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                     name_idx += 1
 
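The dict-style rbln_kwargs plumbing removed above becomes attribute access on a typed config object. A self-contained sketch of the pattern, mirroring the fields used in this hunk (this is not the package's actual class definition):

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class UNetConfigSketch:
        # Field names mirror RBLNUNet2DConditionModelConfig as used above.
        batch_size: int = 1
        max_seq_len: Optional[int] = None
        sample_size: Optional[Tuple[int, int]] = None
        use_additional_residuals: bool = False

    cfg = UNetConfigSketch(batch_size=2)
    cfg.max_seq_len = 77                  # e.g. pipe.text_encoder.config.max_position_embeddings
    cfg.sample_size = (64, 64)            # e.g. cls.get_unet_sample_size(pipe, cfg)
    cfg.use_additional_residuals = True   # e.g. "controlnet" in pipe.config.keys()

The practical effect is that a misspelled setting now fails loudly as an attribute error instead of silently returning None from rbln_kwargs.get(...).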
@@ -286,33 +292,27 @@ class RBLNUNet2DConditionModel(RBLNModel):
             num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
             out_channels = model_config.block_out_channels[-1]
             shape = [
-                batch_size,
+                rbln_config.batch_size,
                 out_channels,
-                sample_size[0] // 2**num_cross_attn_blocks,
-                sample_size[1] // 2**num_cross_attn_blocks,
+                rbln_config.sample_size[0] // 2**num_cross_attn_blocks,
+                rbln_config.sample_size[1] // 2**num_cross_attn_blocks,
             ]
             input_info.append(("mid_block_additional_residual", shape, "float32"))
 
         if hasattr(model_config, "addition_embed_type"):
             if model_config.addition_embed_type == "text_time":
-                rbln_in_features = model_config.projection_class_embeddings_input_dim
-                …
+                rbln_config.in_features = model_config.projection_class_embeddings_input_dim
+                input_info.append(
+                    ("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32")
+                )
+                input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))
             elif model_config.addition_embed_type == "image":
-                …
+                input_info.append(
+                    ("image_embeds", [rbln_config.batch_size, rbln_config.image_model_hidden_size], "float32")
+                )
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        if rbln_in_features is not None:
-            rbln_config.model_cfg["in_features"] = rbln_in_features
+        rbln_config.set_compile_cfgs([rbln_compile_config])
 
         return rbln_config
 
@@ -336,7 +336,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[UNet2DConditionOutput, Tuple]:
         sample_batch_size = sample.size()[0]
         compiled_batch_size = self.compiled_batch_size
         if sample_batch_size != compiled_batch_size and (
@@ -344,8 +344,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
         ):
             raise ValueError(
                 f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
-                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
-                "Adjust the batch size during compilation …
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size of UNet in Stable Diffusion. "
+                "Adjust the batch size of UNet during compilation to match the runtime batch size.\n\n"
                 "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
             )
 
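The batch-size rule behind this error message is plain arithmetic, and it matches the classifier-free-guidance logic removed from update_rbln_config_using_pipe above. A quick illustration:

    # Classifier-free guidance doubles the UNet's runtime batch size:
    # one forward pass evaluates the conditional and unconditional branches.
    prompt_batch = 1
    guidance_scale = 7.5
    time_cond_proj_dim = None  # CFG is skipped when the UNet sets this

    do_classifier_free_guidance = guidance_scale > 1.0 and time_cond_proj_dim is None
    unet_batch = prompt_batch * (2 if do_classifier_free_guidance else 1)
    assert unet_batch == 2  # so the UNet must be compiled with batch_size=2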
@@ -353,31 +353,28 @@ class RBLNUNet2DConditionModel(RBLNModel):
 
         if down_block_additional_residuals is not None:
             down_block_additional_residuals = [t.contiguous() for t in down_block_additional_residuals]
-            return (
-                super().forward(
-                    sample.contiguous(),
-                    timestep.float(),
-                    encoder_hidden_states,
-                    *down_block_additional_residuals,
-                    mid_block_additional_residual,
-                    **added_cond_kwargs,
-                ),
+            return super().forward(
+                sample.contiguous(),
+                timestep.float(),
+                encoder_hidden_states,
+                *down_block_additional_residuals,
+                mid_block_additional_residual,
+                **added_cond_kwargs,
+                return_dict=return_dict,
             )
 
         if "image_embeds" in added_cond_kwargs:
-            return (
-                super().forward(
-                    sample.contiguous(),
-                    timestep.float(),
-                    **added_cond_kwargs,
-                ),
-            )
-
-        return (
-            super().forward(
+            return super().forward(
                 sample.contiguous(),
                 timestep.float(),
-                encoder_hidden_states,
                 **added_cond_kwargs,
-            ),
+                return_dict=return_dict,
+            )
+
+        return super().forward(
+            sample.contiguous(),
+            timestep.float(),
+            encoder_hidden_states,
+            **added_cond_kwargs,
+            return_dict=return_dict,
         )
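Combined with the output_class = UNet2DConditionOutput / output_key = "sample" attributes added earlier, these return super().forward(..., return_dict=return_dict) calls give the compiled UNet the same return contract as the original diffusers module. A tiny self-contained check of that contract, constructing the diffusers output class directly:

    import torch
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

    # return_dict=True yields an output dataclass keyed by "sample";
    # return_dict=False yields a plain tuple (see RBLNModel._prepare_output below).
    out = UNet2DConditionOutput(sample=torch.zeros(2, 4, 64, 64))
    assert out.sample.shape == (2, 4, 64, 64)
    assert out[0] is out.sample  # diffusers outputs also support tuple-style indexing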
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
CHANGED
@@ -81,7 +81,7 @@ class RBLNMultiControlNetModel(RBLNModel):
             model.save_pretrained(real_save_path)
 
     @classmethod
-    def _get_rbln_config(cls, **rbln_config_kwargs):
+    def _update_rbln_config(cls, **rbln_config_kwargs):
         pass
 
     def forward(
@@ -100,16 +100,15 @@ class RBLNMultiControlNetModel(RBLNModel):
         return_dict: bool = True,
     ):
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
-            output = controlnet(
+            down_samples, mid_sample = controlnet(
                 sample=sample.contiguous(),
                 timestep=timestep.float(),
                 encoder_hidden_states=encoder_hidden_states,
                 controlnet_cond=image,
                 conditioning_scale=torch.tensor(scale),
+                return_dict=return_dict,
             )
 
-            down_samples, mid_sample = output[:-1], output[-1]
-
             # merge samples
             if i == 0:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from diffusers import StableDiffusionPipeline
 
 from ...modeling_diffusers import RBLNDiffusionMixin
@@ -19,4 +20,4 @@ from ...modeling_diffusers import RBLNDiffusionMixin
 
 
 class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
     original_class = StableDiffusionPipeline
-    _submodules = ["…
+    _submodules = ["vae", "text_encoder", "unet"]
optimum/rbln/modeling.py
CHANGED
@@ -14,15 +14,16 @@
 
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, …
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import rebel
 import torch
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from transformers import AutoConfig, PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
 
+from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
 from .modeling_base import RBLNBaseModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, use_rbln_config
 from .utils.logging import get_logger
 
 
@@ -48,6 +49,9 @@ class RBLNModel(RBLNBaseModel):
     ```
     """
 
+    output_class = None
+    output_key = "last_hidden_state"
+
     @classmethod
     def update_kwargs(cls, kwargs):
         """
@@ -56,12 +60,7 @@ class RBLNModel(RBLNBaseModel):
         For example, `torchscript`=True should be set because torch.jit
         does not support `transformers` output instances as module output;
         """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-            }
-        )
+        kwargs.update({"torchscript": True})
         return kwargs
 
     @classmethod
@@ -70,7 +69,7 @@ class RBLNModel(RBLNBaseModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNModelConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
@@ -78,30 +77,29 @@ class RBLNModel(RBLNBaseModel):
         """
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model
 
     @classmethod
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
         return compiled_model
 
     @classmethod
-    @use_rbln_config
     def from_model(
         cls,
         model: "PreTrainedModel",
         config: Optional[PretrainedConfig] = None,
-        rbln_config: …
+        rbln_config: Optional[RBLNModelConfig] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         **kwargs,
     ):
         preprocessors = kwargs.pop("preprocessors", [])
-        …
+        rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
 
         # Directory to save compile artifacts(.rbln) and original configs
         if model_save_dir is None:
@@ -141,14 +139,21 @@ class RBLNModel(RBLNBaseModel):
         for preprocessor in preprocessors:
             preprocessor.save_pretrained(save_dir_path / subfolder)
 
-        …
+        # Load submodules
+        if len(cls._rbln_submodules) > 0:
+            rbln_submodules = cls._load_submodules(
+                model=model,
+                model_save_dir=save_dir,
+                rbln_config=rbln_config,
+                **kwargs,
+            )
+        else:
+            rbln_submodules = []
 
         # Get compilation arguments (e.g. input_info)
-        rbln_config: RBLNConfig = cls.get_rbln_config(
-            preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
+        rbln_config: RBLNModelConfig = cls.update_rbln_config(
+            preprocessors=preprocessors, model=model, model_config=config, rbln_config=rbln_config
         )
-        # rbln_config.update_runtime_cfg(rbln_kwargs) # This is done in get_rbln_config
 
         compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
             model, rbln_config=rbln_config
@@ -167,17 +172,6 @@ class RBLNModel(RBLNBaseModel):
         # Save torch artifacts (e.g. embedding matrix if needed.)
         cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
 
-        # Load submodules
-        if len(cls._rbln_submodules) > 0:
-            rbln_submodules = cls._load_submodules(
-                model=model,
-                model_save_dir=save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
-            )
-        else:
-            rbln_submodules = []
-
         # Instantiate
         return cls._from_pretrained(
             model_id=save_dir_path,
@@ -201,8 +195,8 @@ class RBLNModel(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        # Some rbln-…
-        …
+        # Some rbln-config should be applied before loading torch module (i.e. quantized llm)
+        rbln_config: Optional[RBLNModelConfig] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
@@ -222,18 +216,43 @@ class RBLNModel(RBLNBaseModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
-        if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
+        if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
 
-        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            rebel.Runtime(compiled_model, tensor_type="pt", device=device, activate_profiler=activate_profiler)
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
+                activate_profiler=rbln_config.activate_profiler,
+            )
             for compiled_model in compiled_models
         ]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+    def forward(self, *args, return_dict: Optional[bool] = None, **kwargs):
+        if self.hf_library_name == "transformers":
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        else:
+            return_dict = True if return_dict is None else return_dict
+
+        # Get output from the model
         output = self.model[0](*args, **kwargs)
-        return output
+
+        # Format output according to task requirements
+        return self._prepare_output(output, return_dict)
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            if self.output_class is None:
+                return BaseModelOutput(last_hidden_state=output)
+
+            # Create output with the appropriate class and key
+            return self.output_class(**{self.output_key: output})