diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
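The headline addition in this release is the DeepFloyd IF pipeline family (entries 18-27 above), including the `IFSuperResolutionPipeline` whose new source file is shown below. As a quick orientation, here is a minimal two-stage usage sketch adapted from the example docstring embedded in that file; the model IDs, the `encode_prompt` call, and the `pt_to_pil` helper are taken from that docstring rather than verified against the release.

```py
# Minimal sketch adapted from the EXAMPLE_DOC_STRING in the new
# pipeline_if_superresolution.py module shown below.
import torch

from diffusers import IFPipeline, IFSuperResolutionPipeline
from diffusers.utils import pt_to_pil

# Stage I: base 64x64 text-to-image pipeline
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "a photo of a kangaroo wearing an orange hoodie and blue sunglasses"
prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")

# Stage II: the super-resolution pipeline added in this file
super_res_pipe = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
)
super_res_pipe.enable_model_cpu_offload()
image = super_res_pipe(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds).images
image[0].save("./if_stage_II.png")
```

The stage II pipeline reuses the stage I prompt embeddings, which is why `text_encoder=None` can be passed to avoid loading the large T5 encoder a second time.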
diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -0,0 +1,947 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
12
+
13
+ from ...models import UNet2DConditionModel
14
+ from ...schedulers import DDPMScheduler
15
+ from ...utils import (
16
+ BACKENDS_MAPPING,
17
+ is_accelerate_available,
18
+ is_accelerate_version,
19
+ is_bs4_available,
20
+ is_ftfy_available,
21
+ logging,
22
+ randn_tensor,
23
+ replace_example_docstring,
24
+ )
25
+ from ..pipeline_utils import DiffusionPipeline
26
+ from . import IFPipelineOutput
27
+ from .safety_checker import IFSafetyChecker
28
+ from .watermark import IFWatermarker
29
+
30
+
31
+ if is_bs4_available():
32
+ from bs4 import BeautifulSoup
33
+
34
+ if is_ftfy_available():
35
+ import ftfy
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```py
44
+ >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
45
+ >>> from diffusers.utils import pt_to_pil
46
+ >>> import torch
47
+
48
+ >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
49
+ >>> pipe.enable_model_cpu_offload()
50
+
51
+ >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
52
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
53
+
54
+ >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
55
+
56
+ >>> # save intermediate image
57
+ >>> pil_image = pt_to_pil(image)
58
+ >>> pil_image[0].save("./if_stage_I.png")
59
+
60
+ >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
61
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
62
+ ... )
63
+ >>> super_res_1_pipe.enable_model_cpu_offload()
64
+
65
+ >>> image = super_res_1_pipe(
66
+ ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
67
+ ... ).images
68
+ >>> image[0].save("./if_stage_II.png")
69
+ ```
70
+ """
71
+
72
+
73
+ class IFSuperResolutionPipeline(DiffusionPipeline):
74
+ tokenizer: T5Tokenizer
75
+ text_encoder: T5EncoderModel
76
+
77
+ unet: UNet2DConditionModel
78
+ scheduler: DDPMScheduler
79
+ image_noising_scheduler: DDPMScheduler
80
+
81
+ feature_extractor: Optional[CLIPImageProcessor]
82
+ safety_checker: Optional[IFSafetyChecker]
83
+
84
+ watermarker: Optional[IFWatermarker]
85
+
86
+ bad_punct_regex = re.compile(
87
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
88
+ ) # noqa
89
+
90
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
91
+
92
+ def __init__(
93
+ self,
94
+ tokenizer: T5Tokenizer,
95
+ text_encoder: T5EncoderModel,
96
+ unet: UNet2DConditionModel,
97
+ scheduler: DDPMScheduler,
98
+ image_noising_scheduler: DDPMScheduler,
99
+ safety_checker: Optional[IFSafetyChecker],
100
+ feature_extractor: Optional[CLIPImageProcessor],
101
+ watermarker: Optional[IFWatermarker],
102
+ requires_safety_checker: bool = True,
103
+ ):
104
+ super().__init__()
105
+
106
+ if safety_checker is None and requires_safety_checker:
107
+ logger.warning(
108
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
109
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
110
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
111
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
112
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
113
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
114
+ )
115
+
116
+ if safety_checker is not None and feature_extractor is None:
117
+ raise ValueError(
118
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
119
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
120
+ )
121
+
122
+ if unet.config.in_channels != 6:
123
+ logger.warn(
124
+ "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
125
+ )
126
+
127
+ self.register_modules(
128
+ tokenizer=tokenizer,
129
+ text_encoder=text_encoder,
130
+ unet=unet,
131
+ scheduler=scheduler,
132
+ image_noising_scheduler=image_noising_scheduler,
133
+ safety_checker=safety_checker,
134
+ feature_extractor=feature_extractor,
135
+ watermarker=watermarker,
136
+ )
137
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
138
+
139
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
140
+ def enable_sequential_cpu_offload(self, gpu_id=0):
141
+ r"""
142
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
143
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
144
+ when their specific submodule has its `forward` method called.
145
+ """
146
+ if is_accelerate_available():
147
+ from accelerate import cpu_offload
148
+ else:
149
+ raise ImportError("Please install accelerate via `pip install accelerate`")
150
+
151
+ device = torch.device(f"cuda:{gpu_id}")
152
+
153
+ models = [
154
+ self.text_encoder,
155
+ self.unet,
156
+ ]
157
+ for cpu_offloaded_model in models:
158
+ if cpu_offloaded_model is not None:
159
+ cpu_offload(cpu_offloaded_model, device)
160
+
161
+ if self.safety_checker is not None:
162
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
163
+
164
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
165
+ def enable_model_cpu_offload(self, gpu_id=0):
166
+ r"""
167
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
168
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
169
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
170
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
171
+ """
172
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
173
+ from accelerate import cpu_offload_with_hook
174
+ else:
175
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
176
+
177
+ device = torch.device(f"cuda:{gpu_id}")
178
+
179
+ if self.device.type != "cpu":
180
+ self.to("cpu", silence_dtype_warnings=True)
181
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
182
+
183
+ hook = None
184
+
185
+ if self.text_encoder is not None:
186
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
187
+
188
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
189
+ # previous model. This will cause both models to be present on the device at the same time.
190
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
191
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
192
+ # the GPU.
193
+ self.text_encoder_offload_hook = hook
194
+
195
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
196
+
197
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
198
+ self.unet_offload_hook = hook
199
+
200
+ if self.safety_checker is not None:
201
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
202
+
203
+ # We'll offload the last model manually.
204
+ self.final_offload_hook = hook
205
+
206
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
207
+ def remove_all_hooks(self):
208
+ if is_accelerate_available():
209
+ from accelerate.hooks import remove_hook_from_module
210
+ else:
211
+ raise ImportError("Please install accelerate via `pip install accelerate`")
212
+
213
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
214
+ if model is not None:
215
+ remove_hook_from_module(model, recurse=True)
216
+
217
+ self.unet_offload_hook = None
218
+ self.text_encoder_offload_hook = None
219
+ self.final_offload_hook = None
220
+
221
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
222
+ def _text_preprocessing(self, text, clean_caption=False):
223
+ if clean_caption and not is_bs4_available():
224
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
225
+ logger.warn("Setting `clean_caption` to False...")
226
+ clean_caption = False
227
+
228
+ if clean_caption and not is_ftfy_available():
229
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
230
+ logger.warn("Setting `clean_caption` to False...")
231
+ clean_caption = False
232
+
233
+ if not isinstance(text, (tuple, list)):
234
+ text = [text]
235
+
236
+ def process(text: str):
237
+ if clean_caption:
238
+ text = self._clean_caption(text)
239
+ text = self._clean_caption(text)
240
+ else:
241
+ text = text.lower().strip()
242
+ return text
243
+
244
+ return [process(t) for t in text]
245
+
246
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
247
+ def _clean_caption(self, caption):
248
+ caption = str(caption)
249
+ caption = ul.unquote_plus(caption)
250
+ caption = caption.strip().lower()
251
+ caption = re.sub("<person>", "person", caption)
252
+ # urls:
253
+ caption = re.sub(
254
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
255
+ "",
256
+ caption,
257
+ ) # regex for urls
258
+ caption = re.sub(
259
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
260
+ "",
261
+ caption,
262
+ ) # regex for urls
263
+ # html:
264
+ caption = BeautifulSoup(caption, features="html.parser").text
265
+
266
+ # @<nickname>
267
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
268
+
269
+ # 31C0—31EF CJK Strokes
270
+ # 31F0—31FF Katakana Phonetic Extensions
271
+ # 3200—32FF Enclosed CJK Letters and Months
272
+ # 3300—33FF CJK Compatibility
273
+ # 3400—4DBF CJK Unified Ideographs Extension A
274
+ # 4DC0—4DFF Yijing Hexagram Symbols
275
+ # 4E00—9FFF CJK Unified Ideographs
276
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
277
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
278
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
279
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
280
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
281
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
282
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
283
+ #######################################################
284
+
285
+ # все виды тире / all types of dash --> "-"
286
+ caption = re.sub(
287
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
288
+ "-",
289
+ caption,
290
+ )
291
+
292
+ # кавычки к одному стандарту
293
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
294
+ caption = re.sub(r"[‘’]", "'", caption)
295
+
296
+ # &quot;
297
+ caption = re.sub(r"&quot;?", "", caption)
298
+ # &amp
299
+ caption = re.sub(r"&amp", "", caption)
300
+
301
+ # ip adresses:
302
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
303
+
304
+ # article ids:
305
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
306
+
307
+ # \n
308
+ caption = re.sub(r"\\n", " ", caption)
309
+
310
+ # "#123"
311
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
312
+ # "#12345.."
313
+ caption = re.sub(r"#\d{5,}\b", "", caption)
314
+ # "123456.."
315
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
316
+ # filenames:
317
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
318
+
319
+ #
320
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
321
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
322
+
323
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
324
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
325
+
326
+ # this-is-my-cute-cat / this_is_my_cute_cat
327
+ regex2 = re.compile(r"(?:\-|\_)")
328
+ if len(re.findall(regex2, caption)) > 3:
329
+ caption = re.sub(regex2, " ", caption)
330
+
331
+ caption = ftfy.fix_text(caption)
332
+ caption = html.unescape(html.unescape(caption))
333
+
334
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
335
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
336
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
337
+
338
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
339
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
340
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
341
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
342
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
343
+
344
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
345
+
346
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
347
+
348
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
349
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
350
+ caption = re.sub(r"\s+", " ", caption)
351
+
352
+ caption = caption.strip()
353
+
354
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
355
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
356
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
357
+ caption = re.sub(r"^\.\S+$", "", caption)
358
+
359
+ return caption.strip()
360
+
361
+ @property
362
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
363
+ def _execution_device(self):
364
+ r"""
365
+ Returns the device on which the pipeline's models will be executed. After calling
366
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
367
+ hooks.
368
+ """
369
+ if not hasattr(self.unet, "_hf_hook"):
370
+ return self.device
371
+ for module in self.unet.modules():
372
+ if (
373
+ hasattr(module, "_hf_hook")
374
+ and hasattr(module._hf_hook, "execution_device")
375
+ and module._hf_hook.execution_device is not None
376
+ ):
377
+ return torch.device(module._hf_hook.execution_device)
378
+ return self.device
379
+
380
+ @torch.no_grad()
381
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
382
+ def encode_prompt(
383
+ self,
384
+ prompt,
385
+ do_classifier_free_guidance=True,
386
+ num_images_per_prompt=1,
387
+ device=None,
388
+ negative_prompt=None,
389
+ prompt_embeds: Optional[torch.FloatTensor] = None,
390
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
391
+ clean_caption: bool = False,
392
+ ):
393
+ r"""
394
+ Encodes the prompt into text encoder hidden states.
395
+
396
+ Args:
397
+ prompt (`str` or `List[str]`, *optional*):
398
+ prompt to be encoded
399
+ device: (`torch.device`, *optional*):
400
+ torch device to place the resulting embeddings on
401
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
402
+ number of images that should be generated per prompt
403
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
404
+ whether to use classifier free guidance or not
405
+ negative_prompt (`str` or `List[str]`, *optional*):
406
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
407
+ `negative_prompt_embeds` instead.
408
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
409
+ prompt_embeds (`torch.FloatTensor`, *optional*):
410
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
411
+ provided, text embeddings will be generated from `prompt` input argument.
412
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
413
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
414
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
415
+ argument.
416
+ """
417
+ if prompt is not None and negative_prompt is not None:
418
+ if type(prompt) is not type(negative_prompt):
419
+ raise TypeError(
420
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
421
+ f" {type(prompt)}."
422
+ )
423
+
424
+ if device is None:
425
+ device = self._execution_device
426
+
427
+ if prompt is not None and isinstance(prompt, str):
428
+ batch_size = 1
429
+ elif prompt is not None and isinstance(prompt, list):
430
+ batch_size = len(prompt)
431
+ else:
432
+ batch_size = prompt_embeds.shape[0]
433
+
434
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
435
+ max_length = 77
436
+
437
+ if prompt_embeds is None:
438
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
439
+ text_inputs = self.tokenizer(
440
+ prompt,
441
+ padding="max_length",
442
+ max_length=max_length,
443
+ truncation=True,
444
+ add_special_tokens=True,
445
+ return_tensors="pt",
446
+ )
447
+ text_input_ids = text_inputs.input_ids
448
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
449
+
450
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
451
+ text_input_ids, untruncated_ids
452
+ ):
453
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
454
+ logger.warning(
455
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
456
+ f" {max_length} tokens: {removed_text}"
457
+ )
458
+
459
+ attention_mask = text_inputs.attention_mask.to(device)
460
+
461
+ prompt_embeds = self.text_encoder(
462
+ text_input_ids.to(device),
463
+ attention_mask=attention_mask,
464
+ )
465
+ prompt_embeds = prompt_embeds[0]
466
+
467
+ if self.text_encoder is not None:
468
+ dtype = self.text_encoder.dtype
469
+ elif self.unet is not None:
470
+ dtype = self.unet.dtype
471
+ else:
472
+ dtype = None
473
+
474
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
475
+
476
+ bs_embed, seq_len, _ = prompt_embeds.shape
477
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
478
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
479
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
480
+
481
+ # get unconditional embeddings for classifier free guidance
482
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
483
+ uncond_tokens: List[str]
484
+ if negative_prompt is None:
485
+ uncond_tokens = [""] * batch_size
486
+ elif isinstance(negative_prompt, str):
487
+ uncond_tokens = [negative_prompt]
488
+ elif batch_size != len(negative_prompt):
489
+ raise ValueError(
490
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
491
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
492
+ " the batch size of `prompt`."
493
+ )
494
+ else:
495
+ uncond_tokens = negative_prompt
496
+
497
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
498
+ max_length = prompt_embeds.shape[1]
499
+ uncond_input = self.tokenizer(
500
+ uncond_tokens,
501
+ padding="max_length",
502
+ max_length=max_length,
503
+ truncation=True,
504
+ return_attention_mask=True,
505
+ add_special_tokens=True,
506
+ return_tensors="pt",
507
+ )
508
+ attention_mask = uncond_input.attention_mask.to(device)
509
+
510
+ negative_prompt_embeds = self.text_encoder(
511
+ uncond_input.input_ids.to(device),
512
+ attention_mask=attention_mask,
513
+ )
514
+ negative_prompt_embeds = negative_prompt_embeds[0]
515
+
516
+ if do_classifier_free_guidance:
517
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
518
+ seq_len = negative_prompt_embeds.shape[1]
519
+
520
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
521
+
522
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
523
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
524
+
525
+ # For classifier free guidance, we need to do two forward passes.
526
+ # Here we concatenate the unconditional and text embeddings into a single batch
527
+ # to avoid doing two forward passes
528
+ else:
529
+ negative_prompt_embeds = None
530
+
531
+ return prompt_embeds, negative_prompt_embeds
532
+
533
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
534
+ def run_safety_checker(self, image, device, dtype):
535
+ if self.safety_checker is not None:
536
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
537
+ image, nsfw_detected, watermark_detected = self.safety_checker(
538
+ images=image,
539
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
540
+ )
541
+ else:
542
+ nsfw_detected = None
543
+ watermark_detected = None
544
+
545
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
546
+ self.unet_offload_hook.offload()
547
+
548
+ return image, nsfw_detected, watermark_detected
549
+
550
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
551
+ def prepare_extra_step_kwargs(self, generator, eta):
552
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
553
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
554
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
555
+ # and should be between [0, 1]
556
+
557
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
558
+ extra_step_kwargs = {}
559
+ if accepts_eta:
560
+ extra_step_kwargs["eta"] = eta
561
+
562
+ # check if the scheduler accepts generator
563
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
564
+ if accepts_generator:
565
+ extra_step_kwargs["generator"] = generator
566
+ return extra_step_kwargs
567
+
568
+ def check_inputs(
569
+ self,
570
+ prompt,
571
+ image,
572
+ batch_size,
573
+ noise_level,
574
+ callback_steps,
575
+ negative_prompt=None,
576
+ prompt_embeds=None,
577
+ negative_prompt_embeds=None,
578
+ ):
579
+ if (callback_steps is None) or (
580
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
581
+ ):
582
+ raise ValueError(
583
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
584
+ f" {type(callback_steps)}."
585
+ )
586
+
587
+ if prompt is not None and prompt_embeds is not None:
588
+ raise ValueError(
589
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
590
+ " only forward one of the two."
591
+ )
592
+ elif prompt is None and prompt_embeds is None:
593
+ raise ValueError(
594
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
595
+ )
596
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
597
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
598
+
599
+ if negative_prompt is not None and negative_prompt_embeds is not None:
600
+ raise ValueError(
601
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
602
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
603
+ )
604
+
605
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
606
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
607
+ raise ValueError(
608
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
609
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
610
+ f" {negative_prompt_embeds.shape}."
611
+ )
612
+
613
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
614
+ raise ValueError(
615
+ f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
616
+ )
617
+
618
+ if isinstance(image, list):
619
+ check_image_type = image[0]
620
+ else:
621
+ check_image_type = image
622
+
623
+ if (
624
+ not isinstance(check_image_type, torch.Tensor)
625
+ and not isinstance(check_image_type, PIL.Image.Image)
626
+ and not isinstance(check_image_type, np.ndarray)
627
+ ):
628
+ raise ValueError(
629
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
630
+ f" {type(check_image_type)}"
631
+ )
632
+
633
+ if isinstance(image, list):
634
+ image_batch_size = len(image)
635
+ elif isinstance(image, torch.Tensor):
636
+ image_batch_size = image.shape[0]
637
+ elif isinstance(image, PIL.Image.Image):
638
+ image_batch_size = 1
639
+ elif isinstance(image, np.ndarray):
640
+ image_batch_size = image.shape[0]
641
+ else:
642
+ assert False
643
+
644
+ if batch_size != image_batch_size:
645
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
646
+
647
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images
648
+ def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
649
+ shape = (batch_size, num_channels, height, width)
650
+ if isinstance(generator, list) and len(generator) != batch_size:
651
+ raise ValueError(
652
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
653
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
654
+ )
655
+
656
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
657
+
658
+ # scale the initial noise by the standard deviation required by the scheduler
659
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
660
+ return intermediate_images
661
+
662
+ def preprocess_image(self, image, num_images_per_prompt, device):
663
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
664
+ image = [image]
665
+
666
+ if isinstance(image[0], PIL.Image.Image):
667
+ image = [np.array(i).astype(np.float32) / 255.0 for i in image]
668
+
669
+ image = np.stack(image, axis=0) # to np
670
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
671
+ elif isinstance(image[0], np.ndarray):
672
+ image = np.stack(image, axis=0) # to np
673
+ if image.ndim == 5:
674
+ image = image[0]
675
+
676
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
677
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
678
+ dims = image[0].ndim
679
+
680
+ if dims == 3:
681
+ image = torch.stack(image, dim=0)
682
+ elif dims == 4:
683
+ image = torch.concat(image, dim=0)
684
+ else:
685
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
686
+
687
+ image = image.to(device=device, dtype=self.unet.dtype)
688
+
689
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
690
+
691
+ return image
692
+
693
+ @torch.no_grad()
694
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
695
+ def __call__(
696
+ self,
697
+ prompt: Union[str, List[str]] = None,
698
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
699
+ num_inference_steps: int = 50,
700
+ timesteps: List[int] = None,
701
+ guidance_scale: float = 4.0,
702
+ negative_prompt: Optional[Union[str, List[str]]] = None,
703
+ num_images_per_prompt: Optional[int] = 1,
704
+ eta: float = 0.0,
705
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
706
+ prompt_embeds: Optional[torch.FloatTensor] = None,
707
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
708
+ output_type: Optional[str] = "pil",
709
+ return_dict: bool = True,
710
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
711
+ callback_steps: int = 1,
712
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
713
+ noise_level: int = 250,
714
+ clean_caption: bool = True,
715
+ ):
716
+ """
717
+ Function invoked when calling the pipeline for generation.
718
+
719
+ Args:
720
+ prompt (`str` or `List[str]`, *optional*):
721
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
722
+ instead.
723
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
724
+ The image to be upscaled.
725
+ num_inference_steps (`int`, *optional*, defaults to 50):
726
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
727
+ expense of slower inference.
728
+ timesteps (`List[int]`, *optional*):
729
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
730
+ timesteps are used. Must be in descending order.
731
+ guidance_scale (`float`, *optional*, defaults to 4.0):
732
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
733
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
734
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
735
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
736
+ usually at the expense of lower image quality.
737
+ negative_prompt (`str` or `List[str]`, *optional*):
738
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
739
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
740
+ less than `1`).
741
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
742
+ The number of images to generate per prompt.
743
+ eta (`float`, *optional*, defaults to 0.0):
744
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
745
+ [`schedulers.DDIMScheduler`], will be ignored for others.
746
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
747
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
748
+ to make generation deterministic.
749
+ prompt_embeds (`torch.FloatTensor`, *optional*):
750
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
751
+ provided, text embeddings will be generated from `prompt` input argument.
752
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
753
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
754
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
755
+ argument.
756
+ output_type (`str`, *optional*, defaults to `"pil"`):
757
+ The output format of the generated image. Choose between
758
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
759
+ return_dict (`bool`, *optional*, defaults to `True`):
760
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
761
+ callback (`Callable`, *optional*):
762
+ A function that will be called every `callback_steps` steps during inference. The function will be
763
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
764
+ callback_steps (`int`, *optional*, defaults to 1):
765
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
766
+ called at every step.
767
+ cross_attention_kwargs (`dict`, *optional*):
768
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
769
+ `self.processor` in
770
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
771
+ noise_level (`int`, *optional*, defaults to 250):
772
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
773
+ clean_caption (`bool`, *optional*, defaults to `True`):
774
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
775
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
776
+ prompt.
777
+
778
+ Examples:
779
+
780
+ Returns:
781
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
782
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
783
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
784
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
785
+ or watermarked content, according to the `safety_checker`.
786
+ """
787
+ # 1. Check inputs. Raise error if not correct
788
+
789
+ if prompt is not None and isinstance(prompt, str):
790
+ batch_size = 1
791
+ elif prompt is not None and isinstance(prompt, list):
792
+ batch_size = len(prompt)
793
+ else:
794
+ batch_size = prompt_embeds.shape[0]
795
+
796
+ self.check_inputs(
797
+ prompt,
798
+ image,
799
+ batch_size,
800
+ noise_level,
801
+ callback_steps,
802
+ negative_prompt,
803
+ prompt_embeds,
804
+ negative_prompt_embeds,
805
+ )
806
+
807
+ # 2. Define call parameters
808
+
809
+ height = self.unet.config.sample_size
810
+ width = self.unet.config.sample_size
811
+
812
+ device = self._execution_device
813
+
814
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
815
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
816
+ # corresponds to doing no classifier free guidance.
817
+ do_classifier_free_guidance = guidance_scale > 1.0
818
+
819
+ # 3. Encode input prompt
820
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
821
+ prompt,
822
+ do_classifier_free_guidance,
823
+ num_images_per_prompt=num_images_per_prompt,
824
+ device=device,
825
+ negative_prompt=negative_prompt,
826
+ prompt_embeds=prompt_embeds,
827
+ negative_prompt_embeds=negative_prompt_embeds,
828
+ clean_caption=clean_caption,
829
+ )
830
+
831
+ if do_classifier_free_guidance:
832
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
833
+
834
+ # 4. Prepare timesteps
835
+ if timesteps is not None:
836
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
837
+ timesteps = self.scheduler.timesteps
838
+ num_inference_steps = len(timesteps)
839
+ else:
840
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
841
+ timesteps = self.scheduler.timesteps
842
+
843
+ # 5. Prepare intermediate images
844
+ num_channels = self.unet.config.in_channels // 2
845
+ intermediate_images = self.prepare_intermediate_images(
846
+ batch_size * num_images_per_prompt,
847
+ num_channels,
848
+ height,
849
+ width,
850
+ prompt_embeds.dtype,
851
+ device,
852
+ generator,
853
+ )
854
+
855
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
856
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
857
+
858
+ # 7. Prepare upscaled image and noise level
859
+ image = self.preprocess_image(image, num_images_per_prompt, device)
860
+ upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
861
+
862
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
863
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
864
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
865
+
866
+ if do_classifier_free_guidance:
867
+ noise_level = torch.cat([noise_level] * 2)
868
+
869
+ # HACK: see comment in `enable_model_cpu_offload`
870
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
871
+ self.text_encoder_offload_hook.offload()
872
+
873
+ # 8. Denoising loop
874
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
875
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
876
+ for i, t in enumerate(timesteps):
877
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
878
+
879
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
880
+ model_input = self.scheduler.scale_model_input(model_input, t)
881
+
882
+ # predict the noise residual
883
+ noise_pred = self.unet(
884
+ model_input,
885
+ t,
886
+ encoder_hidden_states=prompt_embeds,
887
+ class_labels=noise_level,
888
+ cross_attention_kwargs=cross_attention_kwargs,
889
+ ).sample
890
+
891
+ # perform guidance
892
+ if do_classifier_free_guidance:
893
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
894
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
895
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
896
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
897
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
898
+
899
+ # compute the previous noisy sample x_t -> x_t-1
900
+ intermediate_images = self.scheduler.step(
901
+ noise_pred, t, intermediate_images, **extra_step_kwargs
902
+ ).prev_sample
903
+
904
+ # call the callback, if provided
905
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
906
+ progress_bar.update()
907
+ if callback is not None and i % callback_steps == 0:
908
+ callback(i, t, intermediate_images)
909
+
910
+ image = intermediate_images
911
+
912
+ if output_type == "pil":
913
+ # 9. Post-processing
914
+ image = (image / 2 + 0.5).clamp(0, 1)
915
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
916
+
917
+ # 10. Run safety checker
918
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
919
+
920
+ # 11. Convert to PIL
921
+ image = self.numpy_to_pil(image)
922
+
923
+ # 12. Apply watermark
924
+ if self.watermarker is not None:
925
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
926
+ elif output_type == "pt":
927
+ nsfw_detected = None
928
+ watermark_detected = None
929
+
930
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
931
+ self.unet_offload_hook.offload()
932
+ else:
933
+ # 9. Post-processing
934
+ image = (image / 2 + 0.5).clamp(0, 1)
935
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
936
+
937
+ # 10. Run safety checker
938
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
939
+
940
+ # Offload last model to CPU
941
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
942
+ self.final_offload_hook.offload()
943
+
944
+ if not return_dict:
945
+ return (image, nsfw_detected, watermark_detected)
946
+
947
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)