optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/optimum/rbln/diffusers/models/unet_2d_condition.py
+++ b/optimum/rbln/diffusers/models/unet_2d_condition.py
@@ -23,16 +23,15 @@

 import logging
 from dataclasses import dataclass
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig

 from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
+from ...utils.context import override_auto_classes


 if TYPE_CHECKING:
@@ -126,9 +125,6 @@ class _UNet_SDXL(torch.nn.Module):


 class RBLNUNet2DConditionModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel  # feature extraction
-
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.in_features = self.rbln_config.model_cfg.get("in_features", None)
@@ -146,29 +142,11 @@ class RBLNUNet2DConditionModel(RBLNModel):

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-
-
-
-            **kwargs,
+        with override_auto_classes(
+            config_func=UNet2DConditionModel.load_config,
+            model_func=UNet2DConditionModel.from_pretrained,
         ):
-
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        if kwargs.get("export", None):
-            # This is an ad-hoc to workaround save null values of the config.
-            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
-            # and diffusers model do not support loading by AutoConfig.
-            AutoConfig.from_pretrained = lambda *args, **kwargs: None
-        else:
-            AutoConfig.from_pretrained = UNet2DConditionModel.load_config
-        AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod
```
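Note on the refactor above: 0.1.11 monkey-patched `AutoConfig.from_pretrained`, `AutoModel.from_pretrained`, and `TasksManager.get_model_from_task` inline and restored them only after the call returned (with no try/finally); 0.1.13 scopes the same swap inside a context manager. The helper's body lives in the new `optimum/rbln/utils/context.py` (+58 lines), which this diff view does not show, so the following is only a minimal sketch reconstructed from the call site:

```python
# Minimal sketch of an override_auto_classes context manager, inferred from the
# call site above. The released utils/context.py is not shown in this diff, so
# treat the body (and whether it also patches TasksManager.get_model_from_task,
# as the removed 0.1.11 code did) as an assumption.
from contextlib import contextmanager

from transformers import AutoConfig, AutoModel


@contextmanager
def override_auto_classes(config_func, model_func):
    """Temporarily route AutoConfig/AutoModel loading through diffusers-style loaders."""
    config_orig = AutoConfig.from_pretrained
    model_orig = AutoModel.from_pretrained
    try:
        AutoConfig.from_pretrained = config_func
        AutoModel.from_pretrained = model_func
        yield
    finally:
        # Unlike the removed 0.1.11 code, a finally block restores the
        # originals even when loading raises.
        AutoConfig.from_pretrained = config_orig
        AutoModel.from_pretrained = model_orig
```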
```diff
@@ -185,137 +163,68 @@ class RBLNUNet2DConditionModel(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-        if
-
-            raise ValueError(
-                "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
-            )
-            input_width = rbln_img_width // rbln_vae_scale_factor
-            input_height = rbln_img_height // rbln_vae_scale_factor
-        else:
-            input_width, input_height = model_config.sample_size, model_config.sample_size
+        batch_size = rbln_kwargs.get("batch_size")
+        max_seq_len = rbln_kwargs.get("max_seq_len")
+        sample_size = rbln_kwargs.get("sample_size")
+        is_controlnet = rbln_kwargs.get("is_controlnet")
+        rbln_in_features = None
+
+        if batch_size is None:
+            batch_size = 1
+
+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        if max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")

         input_info = [
-            (
-                "sample",
-                [
-                    rbln_batch_size,
-                    model_config.in_channels,
-                    input_height,
-                    input_width,
-                ],
-                "float32",
-            ),
+            ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
             ("timestep", [], "float32"),
-            (
-                "encoder_hidden_states",
-                [
-                    rbln_batch_size,
-                    rbln_max_seq_len,
-                    model_config.cross_attention_dim,
-                ],
-                "float32",
-            ),
+            ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
         ]

-        if
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            ]
-        )
-        if len(model_config.block_out_channels) > 2:
-            input_info.append(
-                (
-                    f"down_block_additional_residuals_{6}",
-                    [rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
-                    "float32",
-                )
-            )
-            input_info.extend(
-                [
-                    (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
-                        "float32",
-                    )
-                    for i in range(7, 9)
-                ]
-            )
-        if len(model_config.block_out_channels) > 3:
-            input_info.extend(
-                [
-                    (
-                        f"down_block_additional_residuals_{i}",
-                        [rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
-                        "float32",
-                    )
-                    for i in range(9, 12)
-                ]
-            )
-        input_info.append(
-            (
-                "mid_block_additional_residual",
-                [
-                    rbln_batch_size,
-                    model_config.block_out_channels[-1],
-                    input_height // 2 ** (len(model_config.block_out_channels) - 1),
-                    input_width // 2 ** (len(model_config.block_out_channels) - 1),
-                ],
-                "float32",
-            )
-        )
+        if is_controlnet:
+            # down block addtional residuals
+            first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
+            height, width = sample_size[0], sample_size[1]
+            input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
+            name_idx = 1
+            for idx, _ in enumerate(model_config.down_block_types):
+                shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                for _ in range(model_config.layers_per_block):
+                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
+                    name_idx += 1
+                if idx != len(model_config.down_block_types) - 1:
+                    height = height // 2
+                    width = width // 2
+                    shape = [batch_size, model_config.block_out_channels[idx], height, width]
+                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
+                    name_idx += 1
+
+            # mid block addtional residual
+            num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
+            out_channels = model_config.block_out_channels[-1]
+            shape = [
+                batch_size,
+                out_channels,
+                sample_size[0] // 2**num_cross_attn_blocks,
+                sample_size[1] // 2**num_cross_attn_blocks,
+            ]
+            input_info.append(("mid_block_additional_residual", shape, "float32"))

         rbln_compile_config = RBLNCompileConfig(input_info=input_info)

         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-
-
-            if rbln_in_features is None:
-                rbln_in_features = model_config.projection_class_embeddings_input_dim
+            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+            rbln_in_features = model_config.projection_class_embeddings_input_dim
             rbln_compile_config.input_info.append(
-                ("text_embeds", [
+                ("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
             )
-            rbln_compile_config.input_info.append(("time_ids", [
+            rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
@@ -323,14 +232,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
             rbln_kwargs=rbln_kwargs,
         )

-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "use_encode": rbln_use_encode,
-            }
-        )
-
         if rbln_in_features is not None:
             rbln_config.model_cfg["in_features"] = rbln_in_features

```
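The rewritten controlnet branch derives every residual shape from `down_block_types`, `block_out_channels`, and `layers_per_block` instead of the hard-coded indexing it replaces (`block_out_channels[1]` at `// 4`, and so on), so it no longer assumes a four-level SD-style UNet. Replaying the loop with stock SD 1.5 UNet values (illustrative numbers from the standard diffusers config, not taken from this diff) gives the expected 12 down-block residuals plus one mid-block residual:

```python
from types import SimpleNamespace

# Standard SD 1.5 UNet hyperparameters (illustrative; not part of this diff).
cfg = SimpleNamespace(
    block_out_channels=[320, 640, 1280, 1280],
    down_block_types=["CrossAttnDownBlock2D"] * 3 + ["DownBlock2D"],
    layers_per_block=2,
)
batch_size, (height, width) = 1, (64, 64)

# conv_in residual first, then per-block residuals, mirroring the loop above.
shapes = [[batch_size, cfg.block_out_channels[0], height, width]]
for idx, _ in enumerate(cfg.down_block_types):
    shapes += [[batch_size, cfg.block_out_channels[idx], height, width]] * cfg.layers_per_block
    if idx != len(cfg.down_block_types) - 1:  # every block but the last downsamples
        height, width = height // 2, width // 2
        shapes.append([batch_size, cfg.block_out_channels[idx], height, width])

num_cross = cfg.down_block_types.count("CrossAttnDownBlock2D")
mid = [batch_size, cfg.block_out_channels[-1], 64 // 2**num_cross, 64 // 2**num_cross]
print(len(shapes), mid)  # 12 [1, 1280, 8, 8]
```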
```diff
--- a/optimum/rbln/diffusers/pipelines/__init__.py
+++ b/optimum/rbln/diffusers/pipelines/__init__.py
@@ -20,16 +20,44 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING

-from .
-
-
-
-
-
-
-
-
-
-
-
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "controlnet": [
+        "RBLNMultiControlNetModel",
+        "RBLNStableDiffusionControlNetImg2ImgPipeline",
+        "RBLNStableDiffusionControlNetPipeline",
+        "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
+        "RBLNStableDiffusionXLControlNetPipeline",
+    ],
+    "stable_diffusion": [
+        "RBLNStableDiffusionImg2ImgPipeline",
+        "RBLNStableDiffusionPipeline",
+    ],
+    "stable_diffusion_xl": ["RBLNStableDiffusionXLImg2ImgPipeline", "RBLNStableDiffusionXLPipeline"],
+}
+if TYPE_CHECKING:
+    from .controlnet import (
+        RBLNMultiControlNetModel,
+        RBLNStableDiffusionControlNetImg2ImgPipeline,
+        RBLNStableDiffusionControlNetPipeline,
+        RBLNStableDiffusionXLControlNetImg2ImgPipeline,
+        RBLNStableDiffusionXLControlNetPipeline,
+    )
+    from .stable_diffusion import (
+        RBLNStableDiffusionImg2ImgPipeline,
+        RBLNStableDiffusionPipeline,
+    )
+    from .stable_diffusion_xl import RBLNStableDiffusionXLImg2ImgPipeline, RBLNStableDiffusionXLPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
```
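The pipelines package now follows transformers' `_LazyModule` convention: importing it registers the names in `_import_structure` but defers the real submodule imports until first attribute access. Illustrative usage:

```python
# Submodules are only imported on first attribute access. Illustrative only.
import optimum.rbln.diffusers.pipelines as pipelines  # cheap: no pipeline modules loaded yet

# _LazyModule.__getattr__ resolves the name via _import_structure and performs
# the actual `from .stable_diffusion import ...` on demand.
pipe_cls = pipelines.RBLNStableDiffusionPipeline
```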
```diff
--- a/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -52,6 +52,13 @@ class RBLNMultiControlNetModel(RBLNModel):
         self.nets = models
         self.dtype = torch.float32

+    @property
+    def compiled_models(self):
+        cm = []
+        for net in self.nets:
+            cm.extend(net.compiled_models)
+        return cm
+
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         def get_model_from_task(
@@ -102,6 +109,10 @@ class RBLNMultiControlNetModel(RBLNModel):
             real_save_path = save_directory + suffix
             model.save_pretrained(real_save_path)

+    @classmethod
+    def _get_rbln_config(cls, **rbln_config_kwargs):
+        pass
+
     def forward(
         self,
         sample: torch.FloatTensor,
```
```diff
--- a/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -26,205 +26,25 @@ from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 import torch.nn.functional as F
-from diffusers import
+from diffusers import StableDiffusionControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel

-from ....
-from ....
-from
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel


 logger = logging.get_logger(__name__)


-class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
-
-
-    """
-    Pipeline for text-to-image generation using Stable Diffusion with ControlNet.
-
-    This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods
-    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-    It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNet pipeline by:
-    - transferring the checkpoint weights of the original into an optimized RBLN graph,
-    - compiling the resulting graph using the RBLN compiler.
-
-    Args:
-        model_id (`Union[str, Path]`):
-            Can be either:
-                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-    """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_kwargs, _ = RBLNBaseModel.resolve_rbln_config(rbln_config, kwargs)
-
-        device = rbln_kwargs.get("device", None)
-        device_map = rbln_kwargs.get("device_map", None)
-        create_runtimes = rbln_kwargs.get("create_runtimes", None)
-        optimize_host_memory = rbln_kwargs.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_kwargs.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_kwargs},
-            )
-
-        batch_size = rbln_kwargs.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_kwargs},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_pretrained(
-                            model_id=cid.config._name_or_path,
-                            subfolder=subfolder_name,
-                            export=True,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_kwargs},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_pretrained(
-                    model_id=controlnet.config._name_or_path,
-                    subfolder="controlnet",
-                    export=True,
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_kwargs},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        # use for CI to access each compiled model
-        if optimize_host_memory is False:
-            model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
+    original_class = StableDiffusionControlNetPipeline
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]

     def check_inputs(
         self,
```
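This hunk is the heart of the 0.1.13 pipeline refactor: roughly 180 lines of hand-rolled export orchestration (compiling the vae, text_encoder, unet, and controlnet one by one, swapping them into the diffusers pipeline, and rewriting its config) collapse into a three-line declarative class, with the shared machinery moved to the new `optimum/rbln/modeling_diffusers.py` (+400 lines, not shown in this view). Under that contract, wiring up any diffusers pipeline would look like the hypothetical sketch below; `RBLNStableDiffusionInpaintPipeline` does not exist in 0.1.13 and only illustrates the pattern:

```python
# Hypothetical example of the declarative contract; not a class shipped in 0.1.13.
from diffusers import StableDiffusionInpaintPipeline

from optimum.rbln.modeling_diffusers import RBLNDiffusionMixin


class RBLNStableDiffusionInpaintPipeline(RBLNDiffusionMixin, StableDiffusionInpaintPipeline):
    original_class = StableDiffusionInpaintPipeline
    # Submodules the mixin compiles to RBLN graphs during from_pretrained(..., export=True).
    _submodules = ["text_encoder", "unet", "vae"]
```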
```diff
@@ -390,6 +210,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         )

     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
```
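`remove_compile_time_kwargs` comes from the new `optimum/rbln/utils/decorator_utils.py` (+55 lines), whose body is not part of this view. Its name suggests it drops call-time arguments that were already fixed when the graph was compiled; the sketch below is speculative, and the guarded kwarg names are assumptions:

```python
# Speculative sketch of a remove_compile_time_kwargs decorator. Only the fact
# that it wraps __call__ is visible in this diff; the kwarg list and warning
# behavior here are assumptions.
import functools
import logging

logger = logging.getLogger(__name__)

COMPILE_TIME_KWARGS = ("height", "width")  # assumed: shapes are fixed at compile time


def remove_compile_time_kwargs(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        for name in COMPILE_TIME_KWARGS:
            # Drop the kwarg and warn instead of silently passing it to a
            # statically compiled graph that cannot honor it.
            if kwargs.pop(name, None) is not None:
                logger.warning("`%s` is fixed at compile time and is ignored at runtime.", name)
        return func(self, *args, **kwargs)

    return wrapper
```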
```diff
@@ -599,6 +420,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
```