optimum_rbln-0.1.12-py3-none-any.whl → optimum_rbln-0.1.15-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py

```diff
@@ -26,203 +26,25 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import
+from diffusers import StableDiffusionControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel
 
-from ....
-from ....
-from
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
-
-
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for text-to-image generation using Stable Diffusion with ControlNet.
-
-        This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNet pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_config = {} if rbln_config is None else rbln_config
-
-        device = rbln_config.get("device", None)
-        device_map = rbln_config.get("device_map", None)
-        create_runtimes = rbln_config.get("create_runtimes", None)
-        optimize_host_memory = rbln_config.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        batch_size = rbln_config.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_model(
-                            model=cid,
-                            subfolder=subfolder_name,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_config},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_model(
-                    model=controlnet,
-                    subfolder="controlnet",
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_config},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        if optimize_host_memory is False:
-            model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
+    original_class = StableDiffusionControlNetPipeline
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
@@ -388,6 +210,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
    def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -597,6 +420,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
```
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

```diff
@@ -26,207 +26,24 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import
+from diffusers import StableDiffusionControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
-from transformers import CLIPTextModel
 
-from ....
-from ....
-from
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
-
-
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for image-to-image generation using Stable Diffusion with ControlNet.
-
-        This model inherits from [`StableDiffusionControlNetImg2ImgPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion Controlnet pipeline into a RBLNStableDiffusionControlNetImg2Img pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_config = {} if rbln_config is None else rbln_config
-
-        device = rbln_config.get("device", None)
-        device_map = rbln_config.get("device_map", None)
-        create_runtimes = rbln_config.get("create_runtimes", None)
-        optimize_host_memory = rbln_config.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        # compile model, create runtime
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=True,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        batch_size = rbln_config.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=text_encoder.config.max_position_embeddings,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=True,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_model(
-                            model=cid,
-                            subfolder=subfolder_name,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_config},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_model(
-                    model=controlnet,
-                    subfolder="controlnet",
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_config},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.controlnet = controlnet
-
-        # update config to be able to load from file.
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        if optimize_host_memory is False:
-            model.compiled_models = [
-                vae.compiled_models[0],
-                vae.compiled_models[1],
-                text_encoder.compiled_models[0],
-                unet.compiled_models[0],
-            ]
-            if isinstance(controlnet, RBLNMultiControlNetModel):
-                for c_model in controlnet.nets:
-                    model.compiled_models.append(c_model.compiled_models[0])
-            else:
-                model.compiled_models.append(controlnet.compiled_models[0])
-
-        return model
+class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
+    original_class = StableDiffusionControlNetImg2ImgPipeline
+    _submodules = ["text_encoder", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
@@ -386,6 +203,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -594,6 +412,7 @@ class RBLNStableDiffusionControlNetImg2ImgPipeline(StableDiffusionControlNetImg2ImgPipeline):
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
```
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

```diff
@@ -26,208 +26,24 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import
+from diffusers import StableDiffusionXLControlNetPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
-from transformers import CLIPTextModel
 
-from ....
-from ....
-from
-from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
+from ....modeling_diffusers import RBLNDiffusionMixin
+from ....utils.decorator_utils import remove_compile_time_kwargs
+from ...models import RBLNControlNetModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)
 
 
-class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
-
-
-    def from_pretrained(cls, model_id, **kwargs):
-        """
-        Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet.
-
-        This model inherits from [`StableDiffusionXLControlNetPipeline`]. Check the superclass documentation for the generic methods
-        implemented for all pipelines (downloading, saving, running on a particular device, etc.).
-
-        It implements the methods to convert a pre-trained Stable Diffusion XL Controlnet pipeline into a RBLNStableDiffusionXLControlNet pipeline by:
-        - transferring the checkpoint weights of the original into an optimized RBLN graph,
-        - compiling the resulting graph using the RBLN compiler.
-
-        Args:
-            model_id (`Union[str, Path]`):
-                Can be either:
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                    - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
-        """
-        export = kwargs.pop("export", None)
-        vae = kwargs.pop("vae", None)
-        unet = kwargs.pop("unet", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        text_encoder_2 = kwargs.pop("text_encoder_2", None)
-        controlnet = kwargs.pop("controlnet", None)
-        model_save_dir = kwargs.pop("model_save_dir", None)
-        rbln_config = kwargs.pop("rbln_config", None)
-        rbln_config = {} if rbln_config is None else rbln_config
-
-        device = rbln_config.get("device", None)
-        device_map = rbln_config.get("device_map", None)
-        create_runtimes = rbln_config.get("create_runtimes", None)
-        optimize_host_memory = rbln_config.get("optimize_host_memory", None)
-
-        kwargs_dict = {
-            "pretrained_model_name_or_path": model_id,
-            **kwargs,
-        }
-
-        kwargs_dict.update(
-            {
-                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
-                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
-                **(
-                    {"text_encoder": text_encoder}
-                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
-                    else {}
-                ),
-                **(
-                    {"controlnet": controlnet}
-                    if controlnet is not None
-                    and (
-                        isinstance(controlnet, ControlNetModel)
-                        or all(isinstance(c, ControlNetModel) for c in controlnet)
-                    )
-                    else {}
-                ),
-            }
-        )
-
-        with ContextRblnConfig(
-            device=device,
-            device_map=device_map,
-            create_runtimes=create_runtimes,
-            optimze_host_mem=optimize_host_memory,
-        ):
-            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
-
-        if export is None or export is False:
-            return model
-
-        do_classifier_free_guidance = (
-            rbln_config.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
-        )
-
-        if not isinstance(vae, RBLNAutoencoderKL):
-            vae = RBLNAutoencoderKL.from_pretrained(
-                model_id=model_id,
-                subfolder="vae",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_unet_sample_size=model.unet.config.sample_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder, RBLNCLIPTextModel):
-            text_encoder = RBLNCLIPTextModel.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
-            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-                model_id=model_id,
-                subfolder="text_encoder_2",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_config={**rbln_config},
-            )
-
-        batch_size = rbln_config.pop("batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
-
-        if not isinstance(unet, RBLNUNet2DConditionModel):
-            unet = RBLNUNet2DConditionModel.from_pretrained(
-                model_id=model_id,
-                subfolder="unet",
-                export=True,
-                model_save_dir=model_save_dir,
-                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-                rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                rbln_batch_size=unet_batch_size,
-                rbln_use_encode=False,
-                rbln_vae_scale_factor=model.vae_scale_factor,
-                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                rbln_config={**rbln_config},
-            )
-
-        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
-            if isinstance(controlnet, (list, tuple)):
-                multicontrolnet = []
-                for i, cid in enumerate(controlnet):
-                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
-                    multicontrolnet.append(
-                        RBLNControlNetModel.from_model(
-                            model=cid,
-                            subfolder=subfolder_name,
-                            model_save_dir=model_save_dir,
-                            rbln_batch_size=unet_batch_size,
-                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                            rbln_vae_scale_factor=model.vae_scale_factor,
-                            rbln_config={**rbln_config},
-                        )
-                    )
-                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
-                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-            else:
-                controlnet = RBLNControlNetModel.from_model(
-                    model=controlnet,
-                    subfolder="controlnet",
-                    model_save_dir=model_save_dir,
-                    rbln_batch_size=unet_batch_size,
-                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-                    rbln_vae_scale_factor=model.vae_scale_factor,
-                    rbln_config={**rbln_config},
-                )
-                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
-
-        if model_save_dir is not None:
-            # To skip saving original pytorch modules
-            del (model.vae, model.text_encoder, model.unet, model.controlnet)
-
-            # Direct calling of `save_pretrained` causes config.unet = (None, None).
-            # So config must be saved again, later.
-            model.save_pretrained(model_save_dir)
-
-        # replace modules
-        model.vae = vae
-        model.text_encoder = text_encoder
-        model.unet = unet
-        model.text_encoder_2 = text_encoder_2
-        model.controlnet = controlnet
-
-        # update config to be able to load from file
-        update_dict = {
-            "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
-            "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
-            "text_encoder_2": ("optimum.rbln", "RBLNCLIPTextModel"),
-            "controlnet": controlnet_dict,
-        }
-        model.register_to_config(**update_dict)
-
-        if model_save_dir is not None:
-            # overwrite to replace incorrect config
-            model.save_config(model_save_dir)
-
-        return model
+class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
+    original_class = StableDiffusionXLControlNetPipeline
+    _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
 
     def check_inputs(
         self,
@@ -419,6 +235,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
         )
 
     @torch.no_grad()
+    @remove_compile_time_kwargs
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -682,6 +499,7 @@ class RBLNStableDiffusionXLControlNetPipeline(StableDiffusionXLControlNetPipeline):
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
```
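The SDXL variant is the only one of the three whose `_submodules` list grows, adding `text_encoder_2` to the four submodules the other pipelines declare. The mixin itself lives in `optimum/rbln/modeling_diffusers.py` (+329 lines, not shown in this section) and presumably iterates that list; the sketch below is a guess at that control flow under stated assumptions, with `_compile` standing in for the real RBLN compile step.

```python
# Guess at the mixin's control flow; modeling_diffusers.py is not part
# of this section, so every name below (including _compile, a stand-in
# for the real compile step) is hypothetical.
from typing import Any, Dict, List, Optional

def _compile(module: Any, rbln_config: Optional[Dict]) -> Any:
    """Stand-in for the real RBLN compile/runtime-creation step."""
    return module

class RBLNDiffusionMixin:
    original_class: Optional[type] = None
    _submodules: List[str] = []

    @classmethod
    def from_pretrained(cls, model_id: str, *, export: bool = False,
                        rbln_config: Optional[Dict] = None, **kwargs):
        # Plain loading is delegated to the wrapped diffusers pipeline.
        pipe = cls.original_class.from_pretrained(model_id, **kwargs)
        if export:
            # One loop replaces the per-pipeline compile blocks that
            # 0.1.12 duplicated across every ControlNet pipeline.
            for name in cls._submodules:
                setattr(pipe, name, _compile(getattr(pipe, name), rbln_config))
        return pipe
```

Whatever the real mechanism, declaring submodules per class is what lets all four ControlNet pipelines shed their near-identical `from_pretrained` bodies in this release.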