diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,854 @@
|
|
1
|
+
import html
|
2
|
+
import inspect
|
3
|
+
import re
|
4
|
+
import urllib.parse as ul
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
6
|
+
|
7
|
+
import torch
|
8
|
+
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
9
|
+
|
10
|
+
from ...models import UNet2DConditionModel
|
11
|
+
from ...schedulers import DDPMScheduler
|
12
|
+
from ...utils import (
|
13
|
+
BACKENDS_MAPPING,
|
14
|
+
is_accelerate_available,
|
15
|
+
is_accelerate_version,
|
16
|
+
is_bs4_available,
|
17
|
+
is_ftfy_available,
|
18
|
+
logging,
|
19
|
+
randn_tensor,
|
20
|
+
replace_example_docstring,
|
21
|
+
)
|
22
|
+
from ..pipeline_utils import DiffusionPipeline
|
23
|
+
from . import IFPipelineOutput
|
24
|
+
from .safety_checker import IFSafetyChecker
|
25
|
+
from .watermark import IFWatermarker
|
26
|
+
|
27
|
+
|
28
|
+
# Optional third-party dependencies used only by the caption-cleaning path (`clean_caption=True`).
# `_text_preprocessing` re-checks their availability at call time before `_clean_caption` uses them.
if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
35
|
+
|
36
|
+
|
37
|
+
# Usage example spliced into `IFPipeline.__call__`'s docstring by `replace_example_docstring`.
# Each of the three stages writes to its own file (stage I, II, III) so no result is overwritten.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
        >>> from diffusers.utils import pt_to_pil
        >>> import torch

        >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
        >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)

        >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_I.png")

        >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
        ...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
        ... )
        >>> super_res_1_pipe.enable_model_cpu_offload()

        >>> image = super_res_1_pipe(
        ...     image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
        ... ).images

        >>> # save intermediate image
        >>> pil_image = pt_to_pil(image)
        >>> pil_image[0].save("./if_stage_II.png")

        >>> safety_modules = {
        ...     "feature_extractor": pipe.feature_extractor,
        ...     "safety_checker": pipe.safety_checker,
        ...     "watermarker": pipe.watermarker,
        ... }
        >>> super_res_2_pipe = DiffusionPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
        ... )
        >>> super_res_2_pipe.enable_model_cpu_offload()

        >>> image = super_res_2_pipe(
        ...     prompt=prompt,
        ...     image=image,
        ... ).images
        >>> image[0].save("./if_stage_III.png")
        ```
"""
|
86
|
+
|
87
|
+
|
88
|
+
class IFPipeline(DiffusionPipeline):
|
89
|
+
tokenizer: T5Tokenizer
text_encoder: T5EncoderModel

unet: UNet2DConditionModel
scheduler: DDPMScheduler

feature_extractor: Optional[CLIPImageProcessor]
safety_checker: Optional[IFSafetyChecker]

watermarker: Optional[IFWatermarker]

# Runs of these punctuation characters are replaced by a single space in `_clean_caption`.
# Written as one raw string so every backslash reaches the regex engine explicitly instead of relying on
# Python's deprecated pass-through of unknown string escapes such as "\)" (a SyntaxWarning on 3.12+).
# The resulting pattern text is identical to the previous concatenation.
bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}")  # noqa

# Components of this pipeline that are optional (the constructor accepts `None` for the safety modules).
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
105
|
+
|
106
|
+
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """
    Register the DeepFloyd IF stage-I components on the pipeline.

    Args:
        tokenizer (`T5Tokenizer`): tokenizer used by `encode_prompt`.
        text_encoder (`T5EncoderModel`): text encoder used by `encode_prompt`.
        unet (`UNet2DConditionModel`): the conditional UNet.
        scheduler (`DDPMScheduler`): the noise scheduler.
        safety_checker (`IFSafetyChecker`, *optional*): model run by `run_safety_checker`; `None` disables it.
        feature_extractor (`CLIPImageProcessor`, *optional*): builds the safety checker's `clip_input`; required
            whenever `safety_checker` is given.
        watermarker (`IFWatermarker`, *optional*): registered as a pipeline module.
        requires_safety_checker (`bool`, defaults to `True`): stored in the config; when `True` and no safety
            checker is supplied, a warning is logged.

    Raises:
        ValueError: if `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so the concrete pipeline class is interpolated into the message.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
|
145
|
+
|
146
|
+
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the text encoder, the unet and (if present) the safety checker with accelerate's `cpu_offload`.

    This gives the largest memory saving: the models' state dicts are kept on the CPU, the modules are moved to
    `torch.device('meta')`, and weights are only loaded onto `cuda:{gpu_id}` while a submodule's `forward` runs.

    Raises:
        ImportError: if `accelerate` is not installed.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target = torch.device(f"cuda:{gpu_id}")

    for module in (self.text_encoder, self.unet):
        if module is None:
            continue
        cpu_offload(module, target)

    # The safety checker also has its buffers offloaded.
    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=target, offload_buffers=True)
|
169
|
+
|
170
|
+
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole models to the CPU with accelerate's `cpu_offload_with_hook`, one model at a time.

    Each model is moved to `cuda:{gpu_id}` when its `forward` is called and stays there until the next model in
    the chain (text encoder -> unet -> safety checker) runs. Memory savings are smaller than with
    `enable_sequential_cpu_offload`, but the iterative unet runs much faster. The hooks are stored on
    `text_encoder_offload_hook`, `unet_offload_hook` and `final_offload_hook` so models can be offloaded manually.

    Raises:
        ImportError: if `accelerate` is missing or older than 0.17.0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")):
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise the memory savings are not visible (but they probably exist)

    chained_hook = None

    if self.text_encoder is not None:
        _, chained_hook = cpu_offload_with_hook(self.text_encoder, target, prev_module_hook=chained_hook)

        # Accelerate moves the next model onto the device *before* running the previous model's offload hook,
        # so both would briefly share the GPU. IF's T5 text encoder is very large, so this hook is kept to let
        # the text encoder be offloaded manually before the unet is moved to the GPU.
        self.text_encoder_offload_hook = chained_hook

    _, chained_hook = cpu_offload_with_hook(self.unet, target, prev_module_hook=chained_hook)

    # If the safety checker never runs, this hook must be called to offload the unet manually.
    self.unet_offload_hook = chained_hook

    if self.safety_checker is not None:
        _, chained_hook = cpu_offload_with_hook(self.safety_checker, target, prev_module_hook=chained_hook)

    # The last model in the chain is offloaded manually.
    self.final_offload_hook = chained_hook
|
210
|
+
|
211
|
+
def remove_all_hooks(self):
    """
    Detach accelerate hooks from the text encoder, unet and safety checker, then forget the stored offload hooks.

    `None` components are skipped.

    Raises:
        ImportError: if `accelerate` is not installed.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate.hooks import remove_hook_from_module

    present = (m for m in (self.text_encoder, self.unet, self.safety_checker) if m is not None)
    for module in present:
        remove_hook_from_module(module, recurse=True)

    self.unet_offload_hook = self.text_encoder_offload_hook = self.final_offload_hook = None
|
224
|
+
|
225
|
+
@property
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Device on which the pipeline's models actually execute.

    Without accelerate hooks on the unet this is simply `self.device`. After
    `pipeline.enable_sequential_cpu_offload()` the weights live elsewhere, so the device is read from the first
    unet submodule whose accelerate hook reports a non-`None` `execution_device`.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
|
243
|
+
|
244
|
+
@torch.no_grad()
def encode_prompt(
    self,
    prompt,
    do_classifier_free_guidance=True,
    num_images_per_prompt=1,
    device=None,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    clean_caption: bool = False,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on. Defaults to `self._execution_device`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. A single string is applied to every prompt in the batch.
            Ignored when `do_classifier_free_guidance` is `False`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        clean_caption (`bool`, *optional*, defaults to `False`):
            Whether to run the caption-cleaning heuristics of `_text_preprocessing` before tokenization.

    Returns:
        `(prompt_embeds, negative_prompt_embeds)`: the prompt embeddings repeated `num_images_per_prompt` times
        along the batch dimension, and the matching negative embeddings (`None` when classifier-free guidance is
        disabled).

    Raises:
        TypeError: if `prompt` and `negative_prompt` are of different types.
        ValueError: if a list `negative_prompt` does not match the batch size of `prompt`.
    """
    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
    max_length = 77

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because the IF text encoder (T5) was trained"
                f" with a maximum sequence length of {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    # Prefer the text encoder's dtype; fall back to the unet's when the text encoder was not loaded.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.unet is not None:
        dtype = self.unet.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            # Broadcast a single negative prompt over the batch so it yields `batch_size` rows; otherwise the
            # `view(batch_size * num_images_per_prompt, ...)` below would silently reshape the wrong data.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        attention_mask = uncond_input.attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    else:
        negative_prompt_embeds = None

    return prompt_embeds, negative_prompt_embeds
|
395
|
+
|
396
|
+
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on `image`, then offload the unet if an offload hook is registered.

    Args:
        image: images passed to `self.numpy_to_pil` and to the safety checker (format not visible here).
        device: device the feature-extractor output is moved to.
        dtype: dtype the CLIP pixel values are cast to before calling the safety checker.

    Returns:
        `(image, nsfw_detected, watermark_detected)`; the two flags are `None` when no safety checker is set.
    """
    nsfw_detected = watermark_detected = None

    if self.safety_checker is not None:
        pil_images = self.numpy_to_pil(image)
        clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
        image, nsfw_detected, watermark_detected = self.safety_checker(
            images=image,
            clip_input=clip_features.pixel_values.to(dtype=dtype),
        )

    unet_hook = getattr(self, "unet_offload_hook", None)
    if unet_hook is not None:
        unet_hook.offload()

    return image, nsfw_detected, watermark_detected
|
411
|
+
|
412
|
+
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are forwarded only when `step`
    declares a parameter with that name. `eta` is the η of the DDIM paper (https://arxiv.org/abs/2010.02502),
    expected in [0, 1]; schedulers without an `eta` parameter simply do not receive it.

    Returns:
        dict: a subset of `{"eta": eta, "generator": generator}`, in that key order.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
|
429
|
+
|
430
|
+
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the generation arguments, raising on the first inconsistency found.

    Raises:
        ValueError: if `callback_steps` is not a positive `int`; if both or neither of `prompt` and
            `prompt_embeds` are given; if `prompt` is neither a `str` nor a `list`; if both `negative_prompt` and
            `negative_prompt_embeds` are given; or if `prompt_embeds` and `negative_prompt_embeds` differ in shape.
    """
    callback_steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not callback_steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    have_prompt = prompt is not None
    have_prompt_embeds = prompt_embeds is not None

    if have_prompt and have_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not have_prompt and not have_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if have_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    both_embeds = have_prompt_embeds and negative_prompt_embeds is not None
    if both_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
|
471
|
+
|
472
|
+
def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
    """
    Sample the initial noise of shape `(batch_size, num_channels, height, width)` for the denoising loop.

    The noise comes from `randn_tensor` and is scaled by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    noise = randn_tensor(
        (batch_size, num_channels, height, width), generator=generator, device=device, dtype=dtype
    )
    # scale the initial noise by the standard deviation required by the scheduler
    return noise * self.scheduler.init_noise_sigma
|
485
|
+
|
486
|
+
def _text_preprocessing(self, text, clean_caption=False):
    """
    Normalise one prompt or a batch of prompts before tokenization.

    Args:
        text: a single prompt, or a list/tuple of prompts.
        clean_caption (`bool`, defaults to `False`): when `True`, run `_clean_caption` on each prompt. This needs
            `bs4` and `ftfy`; if either is missing, a warning is logged and the flag falls back to `False`.

    Returns:
        `List[str]`: the processed prompts. A single input is wrapped in a one-element list. Without cleaning,
        each prompt is only lower-cased and stripped.
    """
    # `Logger.warn` is a deprecated alias; use `Logger.warning`.
    if clean_caption and not is_bs4_available():
        logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
        logger.warning("Setting `clean_caption` to False...")
        clean_caption = False

    if clean_caption and not is_ftfy_available():
        logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
        logger.warning("Setting `clean_caption` to False...")
        clean_caption = False

    if not isinstance(text, (tuple, list)):
        text = [text]

    def process(text: str):
        if clean_caption:
            # Applied twice, presumably on purpose: one pass can expose new matches for earlier rules.
            # NOTE(review): confirm before collapsing to a single call, as that changes the output.
            text = self._clean_caption(text)
            text = self._clean_caption(text)
        else:
            text = text.lower().strip()
        return text

    return [process(t) for t in text]
|
509
|
+
|
510
|
+
def _clean_caption(self, caption):
    """
    Heuristically strip web-scrape noise from a caption.

    The caption is URL-unquoted and lower-cased. The cleaning steps, in order:
      * remove URLs, HTML markup, @handles and several CJK code-point ranges;
      * normalise dash and quote variants;
      * remove `&quot;`/`&amp` entities, IP addresses, long numeric ids and file names;
      * remove product-code-like tokens, shopping boilerplate, size specs such as `1920x1080`, and punctuation runs
        matched by `self.bad_punct_regex`;
      * collapse whitespace.

    Requires `BeautifulSoup` (bs4) and `ftfy`; `_text_preprocessing` only calls this when both are available.

    Args:
        caption: text to clean; non-strings are converted with `str`.

    Returns:
        `str`: the cleaned caption, with leading and trailing whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalise quotes to a single standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal "\n" sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated quotes / dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # Strip here (the result must be assigned) so the ^/$-anchored cleanups below see the real first and last
    # characters rather than a leading/trailing space left by the whitespace collapse above.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
|
623
|
+
|
624
|
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    num_inference_steps: int = 100,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. Ignored when `timesteps` is passed; the number of steps is then
            `len(timesteps)`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`, *optional*, defaults to self.unet.config.sample_size):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size):
            The width in pixels of the generated image.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns a list of `PIL.Image.Image` (watermarked if
            a watermarker is configured), `"pt"` returns the raw `torch.Tensor` without running the safety
            checker, and any other value returns an `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`,
            where `latents` are the current intermediate images.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, it has three elements: the generated images (format depends on `output_type`), the
        `nsfw_detected` flags and the `watermark_detected` flags reported by the safety checker. Both flag
        entries are `None` when `output_type="pt"`, because the safety checker is not run in that case.
    """
    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    height = height or self.unet.config.sample_size
    width = width or self.unet.config.sample_size

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # No prompt given: the batch size comes from the pre-computed embeddings.
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    if do_classifier_free_guidance:
        # Unconditional embeddings first, so that `noise_pred.chunk(2)` below yields (uncond, text).
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # 4. Prepare timesteps
    if timesteps is not None:
        # Custom timesteps take precedence over `num_inference_steps`.
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # 5. Prepare intermediate images
    intermediate_images = self.prepare_intermediate_images(
        batch_size * num_images_per_prompt,
        self.unet.config.in_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_input = (
                torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
            )
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # The UNet output carries extra channels beyond the input channel count
                # (kept as `predicted_variance`); guidance is applied to the first
                # `model_input.shape[1]` channels only, then the text-branch extra channels are re-attached.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs
            ).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pt":
        # Raw tensor output: safety checker and watermarking are skipped.
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 8. Post-processing: rescale (presumably from [-1, 1]) to [0, 1], move to CPU,
        # and go from channels-first to channels-last numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        if output_type == "pil":
            # 10. Convert to PIL
            image = self.numpy_to_pil(image)

            # 11. Apply watermark (return value is not used; the watermarker is called for its effect on `image`)
            if self.watermarker is not None:
                self.watermarker.apply_watermark(image, self.unet.config.sample_size)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
|