optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
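For orientation, the headline change in this release is the new `RBLNDiffusionMixin` compilation path (`modeling_diffusers.py`), which replaces the per-model `from_pretrained` monkey-patching removed in the hunks below. A minimal usage sketch, assuming the `rbln_`-prefixed kwargs that the hunks read back as `img_height`, `img_width`, and `guidance_scale` config keys (the model id and exact kwarg spellings are illustrative, not verified against the 0.1.15 docs):

```python
# Sketch: compiling a Stable Diffusion pipeline with optimum-rbln 0.1.15.
# The rbln_* kwargs mirror the config keys handled in the hunks below;
# treat the exact spellings as assumptions.
from optimum.rbln import RBLNStableDiffusionPipeline

pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,              # compile the Hugging Face checkpoint for RBLN NPUs
    rbln_img_height=512,      # fixes the UNet/VAE resolution at compile time
    rbln_img_width=512,
    rbln_guidance_scale=7.5,  # > 1.0 enables CFG, so the UNet compiles with batch size 2
)
pipe.save_pretrained("sd15-rbln")  # reuse later without recompiling
image = pipe("a watercolor fox in the snow", guidance_scale=7.5).images[0]
```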
optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py}:

@@ -23,16 +23,15 @@
 
 import logging
 from dataclasses import dataclass
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig
 
-from …
-from …
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
 
 
 if TYPE_CHECKING:
@@ -126,6 +125,9 @@ class _UNet_SDXL(torch.nn.Module):
 
 
 class RBLNUNet2DConditionModel(RBLNModel):
+    hf_library_name = "diffusers"
+    auto_model_class = UNet2DConditionModel
+
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.in_features = self.rbln_config.model_cfg.get("in_features", None)
@@ -141,33 +143,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
 
         self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
 
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
-        ):
-            return UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        if kwargs.get("export", None):
-            # This is an ad-hoc to workaround save null values of the config.
-            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
-            # and diffusers model do not support loading by AutoConfig.
-            AutoConfig.from_pretrained = lambda *args, **kwargs: None
-        else:
-            AutoConfig.from_pretrained = UNet2DConditionModel.load_config
-        AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
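The override deleted here worked by temporarily monkey-patching `TasksManager.get_model_from_task`, `AutoConfig.from_pretrained`, and `AutoModel.from_pretrained` around the `super().from_pretrained()` call, because plain optimum resolves models through transformers' Auto classes and diffusers models cannot be loaded that way. In 0.1.15 the same routing is declared once through the `hf_library_name` / `auto_model_class` attributes added in the previous hunk. A minimal sketch of that dispatch pattern, with hypothetical names (`RBLNModelSketch` and `get_hf_class` are illustrative, not the actual optimum-rbln internals):

```python
# Sketch of attribute-driven loader dispatch, assuming a simplified RBLNModel.
# Class attributes name the source library and loader class, so no global
# monkey-patching of AutoConfig/AutoModel is needed.
import importlib


class RBLNModelSketch:
    hf_library_name = "transformers"   # overridden per subclass, e.g. "diffusers"
    auto_model_class = None            # e.g. UNet2DConditionModel
    _hf_class = None                   # explicit override, e.g. MultiControlNetModel

    @classmethod
    def get_hf_class(cls):
        # Prefer the explicit class; otherwise fall back to the auto class.
        return cls._hf_class or cls.auto_model_class

    @classmethod
    def get_pretrained_model(cls, model_id, **kwargs):
        importlib.import_module(cls.hf_library_name)  # ensure the library is importable
        return cls.get_hf_class().from_pretrained(model_id, **kwargs)


# A UNet wrapper then only declares what to load, mirroring the hunk above:
# class RBLNUNet2DConditionModel(RBLNModelSketch):
#     hf_library_name = "diffusers"
#     auto_model_class = UNet2DConditionModel
```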
@@ -175,6 +150,61 @@ class RBLNUNet2DConditionModel(RBLNModel):
         else:
             return _UNet_SD(model).eval()
 
+    @classmethod
+    def get_unet_sample_size(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]
+    ) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+        elif image_size[0] is None and image_size[1] is None:
+            if rbln_config["img2img_pipeline"]:
+                # In case of img2img, sample size of unet is determined by vae encoder.
+                vae_sample_size = pipe.vae.config.sample_size
+                if isinstance(vae_sample_size, int):
+                    sample_size = vae_sample_size // pipe.vae_scale_factor
+                else:
+                    sample_size = (
+                        vae_sample_size[0] // pipe.vae_scale_factor,
+                        vae_sample_size[1] // pipe.vae_scale_factor,
+                    )
+            else:
+                sample_size = pipe.unet.config.sample_size
+        else:
+            sample_size = (image_size[0] // pipe.vae_scale_factor, image_size[1] // pipe.vae_scale_factor)
+
+        return sample_size
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
+
+        batch_size = rbln_config.get("batch_size")
+        if not batch_size:
+            do_classifier_free_guidance = (
+                rbln_config.get("guidance_scale", 5.0) > 1.0 and pipe.unet.config.time_cond_proj_dim is None
+            )
+            batch_size = 2 if do_classifier_free_guidance else 1
+        else:
+            if rbln_config.get("guidance_scale"):
+                logger.warning(
+                    "guidance_scale is ignored because batch size is explicitly specified. "
+                    "To ensure consistent behavior, consider removing the guidance scale or "
+                    "adjusting the batch size configuration as needed."
+                )
+
+        rbln_config.update(
+            {
+                "max_seq_len": pipe.text_encoder.config.max_position_embeddings,
+                "text_model_hidden_size": text_model_hidden_size,
+                "sample_size": cls.get_unet_sample_size(pipe, rbln_config),
+                "batch_size": batch_size,
+                "is_controlnet": "controlnet" in pipe.config.keys(),
+            }
+        )
+
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -182,137 +212,68 @@ class RBLNUNet2DConditionModel(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        …
-        if …
-            if …
-                raise ValueError(
-                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
-                )
-            input_width = rbln_img_width // rbln_vae_scale_factor
-            input_height = rbln_img_height // rbln_vae_scale_factor
-        else:
-            input_width, input_height = model_config.sample_size, model_config.sample_size
+        batch_size = rbln_kwargs.get("batch_size")
+        max_seq_len = rbln_kwargs.get("max_seq_len")
+        sample_size = rbln_kwargs.get("sample_size")
+        is_controlnet = rbln_kwargs.get("is_controlnet")
+        rbln_in_features = None
+
+        if batch_size is None:
+            batch_size = 1
+
+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        if max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified.")
 
         input_info = [
-            (
-                "sample",
-                [
-                    rbln_batch_size,
-                    model_config.in_channels,
-                    input_height,
-                    input_width,
-                ],
-                "float32",
-            ),
+            ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
             ("timestep", [], "float32"),
-            (
-                "encoder_hidden_states",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                    model_config.cross_attention_dim,
-                ],
-                "float32",
-            ),
+            ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
         ]
 
-        if …
-            …
-        if len(model_config.block_out_channels) > 2:
-            input_info.append(
-                (
-                    f"down_block_additional_residuals_{6}",
-                    [rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
-                    "float32",
-                )
-            )
-            input_info.extend(
-                [
-                    (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
-                        "float32",
-                    )
-                    for i in range(7, 9)
-                ]
-            )
-        if len(model_config.block_out_channels) > 3:
-            input_info.extend(
-                [
-                    (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
-                        "float32",
-                    )
-                    for i in range(9, 12)
-                ]
-            )
-        input_info.append(
-            (
-                "mid_block_additional_residual",
-                [
-                    rbln_batch_size,
-                    model_config.block_out_channels[-1],
-                    input_height // 2 ** (len(model_config.block_out_channels) - 1),
-                    input_width // 2 ** (len(model_config.block_out_channels) - 1),
-                ],
-                "float32",
-            )
-        )
+        if is_controlnet:
+            # down block addtional residuals
+            first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
+            height, width = sample_size[0], sample_size[1]
+            input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
+            name_idx = 1
+            for idx, _ in enumerate(model_config.down_block_types):
+                shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                for _ in range(model_config.layers_per_block):
+                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
+                    name_idx += 1
+                if idx != len(model_config.down_block_types) - 1:
+                    height = height // 2
+                    width = width // 2
+                    shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
+                    name_idx += 1
+
+            # mid block addtional residual
+            num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
+            out_channels = model_config.block_out_channels[-1]
+            shape = [
+                batch_size,
+                out_channels,
+                sample_size[0] // 2**num_cross_attn_blocks,
+                sample_size[1] // 2**num_cross_attn_blocks,
+            ]
+            input_info.append(("mid_block_additional_residual", shape, "float32"))
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            …
-            if rbln_in_features is None:
-                rbln_in_features = model_config.projection_class_embeddings_input_dim
+            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+            rbln_in_features = model_config.projection_class_embeddings_input_dim
             rbln_compile_config.input_info.append(
-                ("text_embeds", […
+                ("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
             )
-            rbln_compile_config.input_info.append(("time_ids", […
+            rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))
 
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
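Where the old code hard-coded ControlNet residual indices 6 to 11 against fixed `block_out_channels` positions, the new loop derives one residual input per ResNet layer plus one per downsampler directly from `down_block_types` and `layers_per_block`. Tracing it with an assumed SD 1.5 UNet config reproduces the 12 down-block residuals and the mid-block residual that diffusers' ControlNet emits:

```python
# Tracing the new residual-shape loop for an assumed SD 1.5 UNet config.
block_out_channels = (320, 640, 1280, 1280)
down_block_types = (
    "CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D", "DownBlock2D",
)
layers_per_block = 2
batch_size, height, width = 2, 64, 64  # latent resolution for 512x512 images

shapes = [[batch_size, block_out_channels[0], height, width]]  # residual_0
for idx, _ in enumerate(down_block_types):
    for _ in range(layers_per_block):  # one residual per ResNet layer
        shapes.append([batch_size, block_out_channels[idx], height, width])
    if idx != len(down_block_types) - 1:  # downsampler on all but the last block
        height, width = height // 2, width // 2
        shapes.append([batch_size, block_out_channels[idx], height, width])

assert len(shapes) == 12  # matches diffusers' down_block_res_samples
mid = [batch_size, block_out_channels[-1], 64 // 2**3, 64 // 2**3]  # 3 CrossAttn blocks
assert mid == [2, 1280, 8, 8]
```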
@@ -320,19 +281,15 @@ class RBLNUNet2DConditionModel(RBLNModel):
             rbln_kwargs=rbln_kwargs,
         )
 
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "use_encode": rbln_use_encode,
-            }
-        )
-
         if rbln_in_features is not None:
             rbln_config.model_cfg["in_features"] = rbln_in_features
 
         return rbln_config
 
+    @property
+    def compiled_batch_size(self):
+        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
+
     def forward(
         self,
         sample: torch.Tensor,
@@ -350,9 +307,18 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return_dict: bool = True,
         **kwargs,
     ):
-        …
+        sample_batch_size = sample.size()[0]
+        compiled_batch_size = self.compiled_batch_size
+        if sample_batch_size != compiled_batch_size and (
+            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
+        ):
+            raise ValueError(
+                f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
+                "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
+                "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
+                "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
+            )
+
         added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
 
         if down_block_additional_residuals is not None:
optimum/rbln/diffusers/pipelines/__init__.py:

@@ -20,16 +20,64 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING
 
-from .…
-…
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "controlnet": [
+        "RBLNMultiControlNetModel",
+        "RBLNStableDiffusionControlNetImg2ImgPipeline",
+        "RBLNStableDiffusionControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusionXLControlNetPipeline",
+    ],
+    "stable_diffusion": [
+        "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionPipeline",
+        "RBLNStableDiffusionInpaintPipeline",
+    ],
+    "stable_diffusion_xl": [
+        "RBLNStableDiffusionXLImg2ImgPipeline",
+        "RBLNStableDiffusionXLPipeline",
+        "RBLNStableDiffusionXLInpaintPipeline",
+    ],
+    "stable_diffusion_3": [
+        "RBLNStableDiffusion3Pipeline",
+        "RBLNStableDiffusion3Img2ImgPipeline",
+        "RBLNStableDiffusion3InpaintPipeline",
+    ],
+}
+if TYPE_CHECKING:
+    from .controlnet import (
+        RBLNMultiControlNetModel,
+        RBLNStableDiffusionControlNetImg2ImgPipeline,
+        RBLNStableDiffusionControlNetPipeline,
+        RBLNStableDiffusionXLControlNetImg2ImgPipeline,
+        RBLNStableDiffusionXLControlNetPipeline,
+    )
+    from .stable_diffusion import (
+        RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionInpaintPipeline,
+        RBLNStableDiffusionPipeline,
+    )
+    from .stable_diffusion_3 import (
+        RBLNStableDiffusion3Img2ImgPipeline,
+        RBLNStableDiffusion3InpaintPipeline,
+        RBLNStableDiffusion3Pipeline,
+    )
+    from .stable_diffusion_xl import (
+        RBLNStableDiffusionXLImg2ImgPipeline,
+        RBLNStableDiffusionXLInpaintPipeline,
+        RBLNStableDiffusionXLPipeline,
+    )
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py:

@@ -27,12 +27,9 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
-from diffusers import ControlNetModel
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel
 
-from ....…
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
@@ -44,6 +41,9 @@ logger = logging.getLogger(__name__)
 
 
 class RBLNMultiControlNetModel(RBLNModel):
+    hf_library_name = "diffusers"
+    _hf_class = MultiControlNetModel
+
     def __init__(
         self,
         models: List[RBLNControlNetModel],
@@ -52,26 +52,12 @@ class RBLNMultiControlNetModel(RBLNModel):
         self.nets = models
         self.dtype = torch.float32
 
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
-        ):
-            return MultiControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        AutoConfig.from_pretrained = ControlNetModel.load_config
-        AutoModel.from_pretrained = MultiControlNetModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
-        return rt
+    @property
+    def compiled_models(self):
+        cm = []
+        for net in self.nets:
+            cm.extend(net.compiled_models)
+        return cm
 
     @classmethod
     def _from_pretrained(
@@ -111,7 +97,7 @@ class RBLNMultiControlNetModel(RBLNModel):
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond: List[torch.…],
+        controlnet_cond: List[torch.Tensor],
         conditioning_scale: List[float],
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,