diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
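For readers who want to verify the listing below, here is a minimal sketch of reproducing a per-file diff between the two wheels locally; it assumes both wheels have already been downloaded (e.g. with `pip download diffusers==0.15.1 --no-deps`) and uses only the standard library:

```py
# Sketch: compare one module across the two downloaded wheels.
import difflib
import zipfile

old = zipfile.ZipFile("diffusers-0.15.1-py3-none-any.whl")
new = zipfile.ZipFile("diffusers-0.16.1-py3-none-any.whl")

name = "diffusers/pipelines/pipeline_utils.py"  # any file listed below
old_src = old.read(name).decode("utf-8").splitlines(keepends=True)
new_src = new.read(name).decode("utf-8").splitlines(keepends=True)

print("".join(difflib.unified_diff(old_src, new_src, f"0.15.1/{name}", f"0.16.1/{name}")))
```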
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
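The bulk of the release is the new `diffusers/pipelines/deepfloyd_if/` package (entries 18-27 above). The single hunk reproduced below is the full content of the newly added `diffusers/pipelines/deepfloyd_if/pipeline_if.py` (854 added lines, matching entry 19). As a quick orientation before the listing, a minimal stage-I sketch based on the file's own example docstring; the prompt here is illustrative only:

```py
import torch
from diffusers import IFPipeline  # exported at the package root, per the example docstring below

pipe = IFPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Encode once; the embeddings can be reused by the super-resolution stages.
prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a red panda")  # illustrative prompt
image = pipe(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
).images
```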
@@ -0,0 +1,854 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import torch
8
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
9
+
10
+ from ...models import UNet2DConditionModel
11
+ from ...schedulers import DDPMScheduler
12
+ from ...utils import (
13
+ BACKENDS_MAPPING,
14
+ is_accelerate_available,
15
+ is_accelerate_version,
16
+ is_bs4_available,
17
+ is_ftfy_available,
18
+ logging,
19
+ randn_tensor,
20
+ replace_example_docstring,
21
+ )
22
+ from ..pipeline_utils import DiffusionPipeline
23
+ from . import IFPipelineOutput
24
+ from .safety_checker import IFSafetyChecker
25
+ from .watermark import IFWatermarker
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+ if is_bs4_available():
31
+ from bs4 import BeautifulSoup
32
+
33
+ if is_ftfy_available():
34
+ import ftfy
35
+
36
+
37
+ EXAMPLE_DOC_STRING = """
38
+ Examples:
39
+ ```py
40
+ >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
41
+ >>> from diffusers.utils import pt_to_pil
42
+ >>> import torch
43
+
44
+ >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
45
+ >>> pipe.enable_model_cpu_offload()
46
+
47
+ >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
48
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
49
+
50
+ >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
51
+
52
+ >>> # save intermediate image
53
+ >>> pil_image = pt_to_pil(image)
54
+ >>> pil_image[0].save("./if_stage_I.png")
55
+
56
+ >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
57
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
58
+ ... )
59
+ >>> super_res_1_pipe.enable_model_cpu_offload()
60
+
61
+ >>> image = super_res_1_pipe(
62
+ ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
63
+ ... ).images
64
+
65
+ >>> # save intermediate image
66
+ >>> pil_image = pt_to_pil(image)
67
+ >>> pil_image[0].save("./if_stage_II.png")
68
+
69
+ >>> safety_modules = {
70
+ ... "feature_extractor": pipe.feature_extractor,
71
+ ... "safety_checker": pipe.safety_checker,
72
+ ... "watermarker": pipe.watermarker,
73
+ ... }
74
+ >>> super_res_2_pipe = DiffusionPipeline.from_pretrained(
75
+ ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
76
+ ... )
77
+ >>> super_res_2_pipe.enable_model_cpu_offload()
78
+
79
+ >>> image = super_res_2_pipe(
80
+ ... prompt=prompt,
81
+ ... image=image,
82
+ ... ).images
83
+ >>> image[0].save("./if_stage_III.png")
84
+ ```
85
+ """
86
+
87
+
88
+ class IFPipeline(DiffusionPipeline):
89
+ tokenizer: T5Tokenizer
90
+ text_encoder: T5EncoderModel
91
+
92
+ unet: UNet2DConditionModel
93
+ scheduler: DDPMScheduler
94
+
95
+ feature_extractor: Optional[CLIPImageProcessor]
96
+ safety_checker: Optional[IFSafetyChecker]
97
+
98
+ watermarker: Optional[IFWatermarker]
99
+
100
+ bad_punct_regex = re.compile(
101
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
102
+ ) # noqa
103
+
104
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
105
+
106
+ def __init__(
107
+ self,
108
+ tokenizer: T5Tokenizer,
109
+ text_encoder: T5EncoderModel,
110
+ unet: UNet2DConditionModel,
111
+ scheduler: DDPMScheduler,
112
+ safety_checker: Optional[IFSafetyChecker],
113
+ feature_extractor: Optional[CLIPImageProcessor],
114
+ watermarker: Optional[IFWatermarker],
115
+ requires_safety_checker: bool = True,
116
+ ):
117
+ super().__init__()
118
+
119
+ if safety_checker is None and requires_safety_checker:
120
+ logger.warning(
121
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
122
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
123
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
124
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
125
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
126
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
127
+ )
128
+
129
+ if safety_checker is not None and feature_extractor is None:
130
+ raise ValueError(
131
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
132
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
133
+ )
134
+
135
+ self.register_modules(
136
+ tokenizer=tokenizer,
137
+ text_encoder=text_encoder,
138
+ unet=unet,
139
+ scheduler=scheduler,
140
+ safety_checker=safety_checker,
141
+ feature_extractor=feature_extractor,
142
+ watermarker=watermarker,
143
+ )
144
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
145
+
146
+ def enable_sequential_cpu_offload(self, gpu_id=0):
147
+ r"""
148
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
149
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to the GPU only
150
+ when their specific submodule has its `forward` method called.
151
+ """
152
+ if is_accelerate_available():
153
+ from accelerate import cpu_offload
154
+ else:
155
+ raise ImportError("Please install accelerate via `pip install accelerate`")
156
+
157
+ device = torch.device(f"cuda:{gpu_id}")
158
+
159
+ models = [
160
+ self.text_encoder,
161
+ self.unet,
162
+ ]
163
+ for cpu_offloaded_model in models:
164
+ if cpu_offloaded_model is not None:
165
+ cpu_offload(cpu_offloaded_model, device)
166
+
167
+ if self.safety_checker is not None:
168
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
169
+
170
+ def enable_model_cpu_offload(self, gpu_id=0):
171
+ r"""
172
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
173
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
174
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
175
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
176
+ """
177
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
178
+ from accelerate import cpu_offload_with_hook
179
+ else:
180
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
181
+
182
+ device = torch.device(f"cuda:{gpu_id}")
183
+
184
+ if self.device.type != "cpu":
185
+ self.to("cpu", silence_dtype_warnings=True)
186
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
187
+
188
+ hook = None
189
+
190
+ if self.text_encoder is not None:
191
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
192
+
193
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
194
+ # previous model. This will cause both models to be present on the device at the same time.
195
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
196
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
197
+ # the GPU.
198
+ self.text_encoder_offload_hook = hook
199
+
200
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
201
+
202
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
203
+ self.unet_offload_hook = hook
204
+
205
+ if self.safety_checker is not None:
206
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
207
+
208
+ # We'll offload the last model manually.
209
+ self.final_offload_hook = hook
210
+
211
+ def remove_all_hooks(self):
212
+ if is_accelerate_available():
213
+ from accelerate.hooks import remove_hook_from_module
214
+ else:
215
+ raise ImportError("Please install accelerate via `pip install accelerate`")
216
+
217
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
218
+ if model is not None:
219
+ remove_hook_from_module(model, recurse=True)
220
+
221
+ self.unet_offload_hook = None
222
+ self.text_encoder_offload_hook = None
223
+ self.final_offload_hook = None
224
+
225
+ @property
226
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
227
+ def _execution_device(self):
228
+ r"""
229
+ Returns the device on which the pipeline's models will be executed. After calling
230
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
231
+ hooks.
232
+ """
233
+ if not hasattr(self.unet, "_hf_hook"):
234
+ return self.device
235
+ for module in self.unet.modules():
236
+ if (
237
+ hasattr(module, "_hf_hook")
238
+ and hasattr(module._hf_hook, "execution_device")
239
+ and module._hf_hook.execution_device is not None
240
+ ):
241
+ return torch.device(module._hf_hook.execution_device)
242
+ return self.device
243
+
244
+ @torch.no_grad()
245
+ def encode_prompt(
246
+ self,
247
+ prompt,
248
+ do_classifier_free_guidance=True,
249
+ num_images_per_prompt=1,
250
+ device=None,
251
+ negative_prompt=None,
252
+ prompt_embeds: Optional[torch.FloatTensor] = None,
253
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
254
+ clean_caption: bool = False,
255
+ ):
256
+ r"""
257
+ Encodes the prompt into text encoder hidden states.
258
+
259
+ Args:
260
+ prompt (`str` or `List[str]`, *optional*):
261
+ prompt to be encoded
262
+ device: (`torch.device`, *optional*):
263
+ torch device to place the resulting embeddings on
264
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
265
+ number of images that should be generated per prompt
266
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
267
+ whether to use classifier free guidance or not
268
+ negative_prompt (`str` or `List[str]`, *optional*):
269
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
270
+ `negative_prompt_embeds` instead.
271
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
272
+ prompt_embeds (`torch.FloatTensor`, *optional*):
273
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
274
+ provided, text embeddings will be generated from `prompt` input argument.
275
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
276
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
277
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
278
+ argument.
279
+ """
280
+ if prompt is not None and negative_prompt is not None:
281
+ if type(prompt) is not type(negative_prompt):
282
+ raise TypeError(
283
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
284
+ f" {type(prompt)}."
285
+ )
286
+
287
+ if device is None:
288
+ device = self._execution_device
289
+
290
+ if prompt is not None and isinstance(prompt, str):
291
+ batch_size = 1
292
+ elif prompt is not None and isinstance(prompt, list):
293
+ batch_size = len(prompt)
294
+ else:
295
+ batch_size = prompt_embeds.shape[0]
296
+
297
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
298
+ max_length = 77
299
+
300
+ if prompt_embeds is None:
301
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
302
+ text_inputs = self.tokenizer(
303
+ prompt,
304
+ padding="max_length",
305
+ max_length=max_length,
306
+ truncation=True,
307
+ add_special_tokens=True,
308
+ return_tensors="pt",
309
+ )
310
+ text_input_ids = text_inputs.input_ids
311
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
312
+
313
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
314
+ text_input_ids, untruncated_ids
315
+ ):
316
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
317
+ logger.warning(
318
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
319
+ f" {max_length} tokens: {removed_text}"
320
+ )
321
+
322
+ attention_mask = text_inputs.attention_mask.to(device)
323
+
324
+ prompt_embeds = self.text_encoder(
325
+ text_input_ids.to(device),
326
+ attention_mask=attention_mask,
327
+ )
328
+ prompt_embeds = prompt_embeds[0]
329
+
330
+ if self.text_encoder is not None:
331
+ dtype = self.text_encoder.dtype
332
+ elif self.unet is not None:
333
+ dtype = self.unet.dtype
334
+ else:
335
+ dtype = None
336
+
337
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
338
+
339
+ bs_embed, seq_len, _ = prompt_embeds.shape
340
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
341
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
342
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
343
+
344
+ # get unconditional embeddings for classifier free guidance
345
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
346
+ uncond_tokens: List[str]
347
+ if negative_prompt is None:
348
+ uncond_tokens = [""] * batch_size
349
+ elif isinstance(negative_prompt, str):
350
+ uncond_tokens = [negative_prompt]
351
+ elif batch_size != len(negative_prompt):
352
+ raise ValueError(
353
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
354
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
355
+ " the batch size of `prompt`."
356
+ )
357
+ else:
358
+ uncond_tokens = negative_prompt
359
+
360
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
361
+ max_length = prompt_embeds.shape[1]
362
+ uncond_input = self.tokenizer(
363
+ uncond_tokens,
364
+ padding="max_length",
365
+ max_length=max_length,
366
+ truncation=True,
367
+ return_attention_mask=True,
368
+ add_special_tokens=True,
369
+ return_tensors="pt",
370
+ )
371
+ attention_mask = uncond_input.attention_mask.to(device)
372
+
373
+ negative_prompt_embeds = self.text_encoder(
374
+ uncond_input.input_ids.to(device),
375
+ attention_mask=attention_mask,
376
+ )
377
+ negative_prompt_embeds = negative_prompt_embeds[0]
378
+
379
+ if do_classifier_free_guidance:
380
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
381
+ seq_len = negative_prompt_embeds.shape[1]
382
+
383
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
384
+
385
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
386
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
387
+
388
+ # For classifier free guidance, we need to do two forward passes.
389
+ # Here we concatenate the unconditional and text embeddings into a single batch
390
+ # to avoid doing two forward passes
391
+ else:
392
+ negative_prompt_embeds = None
393
+
394
+ return prompt_embeds, negative_prompt_embeds
395
+
396
+ def run_safety_checker(self, image, device, dtype):
397
+ if self.safety_checker is not None:
398
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
399
+ image, nsfw_detected, watermark_detected = self.safety_checker(
400
+ images=image,
401
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
402
+ )
403
+ else:
404
+ nsfw_detected = None
405
+ watermark_detected = None
406
+
407
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
408
+ self.unet_offload_hook.offload()
409
+
410
+ return image, nsfw_detected, watermark_detected
411
+
412
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
413
+ def prepare_extra_step_kwargs(self, generator, eta):
414
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
415
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
416
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
417
+ # and should be between [0, 1]
418
+
419
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
420
+ extra_step_kwargs = {}
421
+ if accepts_eta:
422
+ extra_step_kwargs["eta"] = eta
423
+
424
+ # check if the scheduler accepts generator
425
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
426
+ if accepts_generator:
427
+ extra_step_kwargs["generator"] = generator
428
+ return extra_step_kwargs
429
+
430
+ def check_inputs(
431
+ self,
432
+ prompt,
433
+ callback_steps,
434
+ negative_prompt=None,
435
+ prompt_embeds=None,
436
+ negative_prompt_embeds=None,
437
+ ):
438
+ if (callback_steps is None) or (
439
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
440
+ ):
441
+ raise ValueError(
442
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
443
+ f" {type(callback_steps)}."
444
+ )
445
+
446
+ if prompt is not None and prompt_embeds is not None:
447
+ raise ValueError(
448
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
449
+ " only forward one of the two."
450
+ )
451
+ elif prompt is None and prompt_embeds is None:
452
+ raise ValueError(
453
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
454
+ )
455
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
456
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
457
+
458
+ if negative_prompt is not None and negative_prompt_embeds is not None:
459
+ raise ValueError(
460
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
461
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
462
+ )
463
+
464
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
465
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
466
+ raise ValueError(
467
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
468
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
469
+ f" {negative_prompt_embeds.shape}."
470
+ )
471
+
472
+ def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
473
+ shape = (batch_size, num_channels, height, width)
474
+ if isinstance(generator, list) and len(generator) != batch_size:
475
+ raise ValueError(
476
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
477
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
478
+ )
479
+
480
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
481
+
482
+ # scale the initial noise by the standard deviation required by the scheduler
483
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
484
+ return intermediate_images
485
+
486
+ def _text_preprocessing(self, text, clean_caption=False):
487
+ if clean_caption and not is_bs4_available():
488
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
489
+ logger.warn("Setting `clean_caption` to False...")
490
+ clean_caption = False
491
+
492
+ if clean_caption and not is_ftfy_available():
493
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
494
+ logger.warn("Setting `clean_caption` to False...")
495
+ clean_caption = False
496
+
497
+ if not isinstance(text, (tuple, list)):
498
+ text = [text]
499
+
500
+ def process(text: str):
501
+ if clean_caption:
502
+ text = self._clean_caption(text)
503
+ text = self._clean_caption(text)
504
+ else:
505
+ text = text.lower().strip()
506
+ return text
507
+
508
+ return [process(t) for t in text]
509
+
510
+ def _clean_caption(self, caption):
511
+ caption = str(caption)
512
+ caption = ul.unquote_plus(caption)
513
+ caption = caption.strip().lower()
514
+ caption = re.sub("<person>", "person", caption)
515
+ # urls:
516
+ caption = re.sub(
517
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
518
+ "",
519
+ caption,
520
+ ) # regex for urls
521
+ caption = re.sub(
522
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
523
+ "",
524
+ caption,
525
+ ) # regex for urls
526
+ # html:
527
+ caption = BeautifulSoup(caption, features="html.parser").text
528
+
529
+ # @<nickname>
530
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
531
+
532
+ # 31C0—31EF CJK Strokes
533
+ # 31F0—31FF Katakana Phonetic Extensions
534
+ # 3200—32FF Enclosed CJK Letters and Months
535
+ # 3300—33FF CJK Compatibility
536
+ # 3400—4DBF CJK Unified Ideographs Extension A
537
+ # 4DC0—4DFF Yijing Hexagram Symbols
538
+ # 4E00—9FFF CJK Unified Ideographs
539
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
540
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
541
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
542
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
543
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
544
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
545
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
546
+ #######################################################
547
+
548
+ # all types of dash --> "-"
549
+ caption = re.sub(
550
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
551
+ "-",
552
+ caption,
553
+ )
554
+
555
+ # normalize quotation marks to a single standard
556
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
557
+ caption = re.sub(r"[‘’]", "'", caption)
558
+
559
+ # &quot;
560
+ caption = re.sub(r"&quot;?", "", caption)
561
+ # &amp
562
+ caption = re.sub(r"&amp", "", caption)
563
+
564
+ # ip addresses:
565
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
566
+
567
+ # article ids:
568
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
569
+
570
+ # \n
571
+ caption = re.sub(r"\\n", " ", caption)
572
+
573
+ # "#123"
574
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
575
+ # "#12345.."
576
+ caption = re.sub(r"#\d{5,}\b", "", caption)
577
+ # "123456.."
578
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
579
+ # filenames:
580
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
581
+
582
+ #
583
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
584
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
585
+
586
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
587
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
588
+
589
+ # this-is-my-cute-cat / this_is_my_cute_cat
590
+ regex2 = re.compile(r"(?:\-|\_)")
591
+ if len(re.findall(regex2, caption)) > 3:
592
+ caption = re.sub(regex2, " ", caption)
593
+
594
+ caption = ftfy.fix_text(caption)
595
+ caption = html.unescape(html.unescape(caption))
596
+
597
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
598
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
599
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
600
+
601
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
602
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
603
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
604
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
605
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
606
+
607
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
608
+
609
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
610
+
611
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
612
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
613
+ caption = re.sub(r"\s+", " ", caption)
614
+
615
+ caption = caption.strip()
616
+
617
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
618
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
619
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
620
+ caption = re.sub(r"^\.\S+$", "", caption)
621
+
622
+ return caption.strip()
623
+
624
+ @torch.no_grad()
625
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
626
+ def __call__(
627
+ self,
628
+ prompt: Union[str, List[str]] = None,
629
+ num_inference_steps: int = 100,
630
+ timesteps: List[int] = None,
631
+ guidance_scale: float = 7.0,
632
+ negative_prompt: Optional[Union[str, List[str]]] = None,
633
+ num_images_per_prompt: Optional[int] = 1,
634
+ height: Optional[int] = None,
635
+ width: Optional[int] = None,
636
+ eta: float = 0.0,
637
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
638
+ prompt_embeds: Optional[torch.FloatTensor] = None,
639
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
640
+ output_type: Optional[str] = "pil",
641
+ return_dict: bool = True,
642
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
643
+ callback_steps: int = 1,
644
+ clean_caption: bool = True,
645
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
646
+ ):
647
+ """
648
+ Function invoked when calling the pipeline for generation.
649
+
650
+ Args:
651
+ prompt (`str` or `List[str]`, *optional*):
652
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
653
+ instead.
654
+ num_inference_steps (`int`, *optional*, defaults to 100):
655
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
656
+ expense of slower inference.
657
+ timesteps (`List[int]`, *optional*):
658
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
659
+ timesteps are used. Must be in descending order.
660
+ guidance_scale (`float`, *optional*, defaults to 7.0):
661
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
662
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
663
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
664
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
665
+ usually at the expense of lower image quality.
666
+ negative_prompt (`str` or `List[str]`, *optional*):
667
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
668
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
669
+ less than `1`).
670
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
671
+ The number of images to generate per prompt.
672
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
673
+ The height in pixels of the generated image.
674
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
675
+ The width in pixels of the generated image.
676
+ eta (`float`, *optional*, defaults to 0.0):
677
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
678
+ [`schedulers.DDIMScheduler`], will be ignored for others.
679
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
680
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
681
+ to make generation deterministic.
682
+ prompt_embeds (`torch.FloatTensor`, *optional*):
683
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
684
+ provided, text embeddings will be generated from `prompt` input argument.
685
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
686
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
687
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
688
+ argument.
689
+ output_type (`str`, *optional*, defaults to `"pil"`):
690
+ The output format of the generated image. Choose between
691
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
692
+ return_dict (`bool`, *optional*, defaults to `True`):
693
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
694
+ callback (`Callable`, *optional*):
695
+ A function that will be called every `callback_steps` steps during inference. The function will be
696
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
697
+ callback_steps (`int`, *optional*, defaults to 1):
698
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
699
+ called at every step.
700
+ clean_caption (`bool`, *optional*, defaults to `True`):
701
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
702
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
703
+ prompt.
704
+ cross_attention_kwargs (`dict`, *optional*):
705
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
706
+ `self.processor` in
707
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
708
+
709
+ Examples:
710
+
711
+ Returns:
712
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
713
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
714
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
715
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
716
+ or watermarked content, according to the `safety_checker`.
717
+ """
718
+ # 1. Check inputs. Raise error if not correct
719
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
720
+
721
+ # 2. Define call parameters
722
+ height = height or self.unet.config.sample_size
723
+ width = width or self.unet.config.sample_size
724
+
725
+ if prompt is not None and isinstance(prompt, str):
726
+ batch_size = 1
727
+ elif prompt is not None and isinstance(prompt, list):
728
+ batch_size = len(prompt)
729
+ else:
730
+ batch_size = prompt_embeds.shape[0]
731
+
732
+ device = self._execution_device
733
+
734
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
735
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
736
+ # corresponds to doing no classifier free guidance.
737
+ do_classifier_free_guidance = guidance_scale > 1.0
738
+
739
+ # 3. Encode input prompt
740
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
741
+ prompt,
742
+ do_classifier_free_guidance,
743
+ num_images_per_prompt=num_images_per_prompt,
744
+ device=device,
745
+ negative_prompt=negative_prompt,
746
+ prompt_embeds=prompt_embeds,
747
+ negative_prompt_embeds=negative_prompt_embeds,
748
+ clean_caption=clean_caption,
749
+ )
750
+
751
+ if do_classifier_free_guidance:
752
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
753
+
754
+ # 4. Prepare timesteps
755
+ if timesteps is not None:
756
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
757
+ timesteps = self.scheduler.timesteps
758
+ num_inference_steps = len(timesteps)
759
+ else:
760
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
761
+ timesteps = self.scheduler.timesteps
762
+
763
+ # 5. Prepare intermediate images
764
+ intermediate_images = self.prepare_intermediate_images(
765
+ batch_size * num_images_per_prompt,
766
+ self.unet.config.in_channels,
767
+ height,
768
+ width,
769
+ prompt_embeds.dtype,
770
+ device,
771
+ generator,
772
+ )
773
+
774
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
775
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
776
+
777
+ # HACK: see comment in `enable_model_cpu_offload`
778
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
779
+ self.text_encoder_offload_hook.offload()
780
+
781
+ # 7. Denoising loop
782
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
783
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
784
+ for i, t in enumerate(timesteps):
785
+ model_input = (
786
+ torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
787
+ )
788
+ model_input = self.scheduler.scale_model_input(model_input, t)
789
+
790
+ # predict the noise residual
791
+ noise_pred = self.unet(
792
+ model_input,
793
+ t,
794
+ encoder_hidden_states=prompt_embeds,
795
+ cross_attention_kwargs=cross_attention_kwargs,
796
+ ).sample
797
+
798
+ # perform guidance
799
+ if do_classifier_free_guidance:
800
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
801
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
802
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
803
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
804
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
805
+
806
+ # compute the previous noisy sample x_t -> x_t-1
807
+ intermediate_images = self.scheduler.step(
808
+ noise_pred, t, intermediate_images, **extra_step_kwargs
809
+ ).prev_sample
810
+
811
+ # call the callback, if provided
812
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
813
+ progress_bar.update()
814
+ if callback is not None and i % callback_steps == 0:
815
+ callback(i, t, intermediate_images)
816
+
817
+ image = intermediate_images
818
+
819
+ if output_type == "pil":
820
+ # 8. Post-processing
821
+ image = (image / 2 + 0.5).clamp(0, 1)
822
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
823
+
824
+ # 9. Run safety checker
825
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
826
+
827
+ # 10. Convert to PIL
828
+ image = self.numpy_to_pil(image)
829
+
830
+ # 11. Apply watermark
831
+ if self.watermarker is not None:
832
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
833
+ elif output_type == "pt":
834
+ nsfw_detected = None
835
+ watermark_detected = None
836
+
837
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
838
+ self.unet_offload_hook.offload()
839
+ else:
840
+ # 8. Post-processing
841
+ image = (image / 2 + 0.5).clamp(0, 1)
842
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
843
+
844
+ # 9. Run safety checker
845
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
846
+
847
+ # Offload last model to CPU
848
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
849
+ self.final_offload_hook.offload()
850
+
851
+ if not return_dict:
852
+ return (image, nsfw_detected, watermark_detected)
853
+
854
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
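The pipeline above ships two complementary offloading paths (`enable_sequential_cpu_offload` and `enable_model_cpu_offload`) plus `remove_all_hooks` to undo them. A minimal usage sketch, assuming a single CUDA GPU and `accelerate>=0.17.0` (the version `enable_model_cpu_offload` checks for); the prompt is illustrative:

```py
import torch
from diffusers import IFPipeline

pipe = IFPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)

# Whole-model offloading: each model is moved to the GPU only for its forward
# pass and stays there until the next model runs (moderate savings, fast).
pipe.enable_model_cpu_offload()

# Alternatively, submodule-level offloading for maximal memory savings at a
# large speed cost:
# pipe.enable_sequential_cpu_offload()

image = pipe("a watercolor painting of a lighthouse", output_type="pt").images

# Detach the accelerate hooks again, e.g. before moving the pipeline manually with .to().
pipe.remove_all_hooks()
```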