diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
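The headline addition in 0.16.x is the new DeepFloyd IF pipeline family (files 18–27 above). For orientation, here is a minimal two-stage text-to-image sketch that mirrors the usage pattern shown in the inpainting docstring further down; the checkpoint names and the `IFPipeline`/`IFSuperResolutionPipeline` entry points are taken from that docstring and the file list, so treat the exact call as an illustrative approximation rather than authoritative API documentation.

```py
>>> # Editor's sketch: stage-I text-to-image plus stage-II upscaling with the new IF pipelines.
>>> import torch
>>> from diffusers import IFPipeline, IFSuperResolutionPipeline
>>> from diffusers.utils import pt_to_pil

>>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
>>> pipe.enable_model_cpu_offload()

>>> prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a corgi wearing sunglasses")
>>> image = pipe(
...     prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
... ).images

>>> super_res = IFSuperResolutionPipeline.from_pretrained(
...     "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
... )
>>> super_res.enable_model_cpu_offload()
>>> image = super_res(
...     image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
... ).images
>>> pt_to_pil(image)[0].save("./if_stage_II.png")
```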
diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py (new file)
@@ -0,0 +1,1208 @@
1
+ import html
2
+ import inspect
3
+ import re
4
+ import urllib.parse as ul
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import PIL
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
12
+
13
+ from ...models import UNet2DConditionModel
14
+ from ...schedulers import DDPMScheduler
15
+ from ...utils import (
16
+ BACKENDS_MAPPING,
17
+ PIL_INTERPOLATION,
18
+ is_accelerate_available,
19
+ is_accelerate_version,
20
+ is_bs4_available,
21
+ is_ftfy_available,
22
+ logging,
23
+ randn_tensor,
24
+ replace_example_docstring,
25
+ )
26
+ from ..pipeline_utils import DiffusionPipeline
27
+ from . import IFPipelineOutput
28
+ from .safety_checker import IFSafetyChecker
29
+ from .watermark import IFWatermarker
30
+
31
+
32
+ if is_bs4_available():
33
+ from bs4 import BeautifulSoup
34
+
35
+ if is_ftfy_available():
36
+ import ftfy
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
43
+ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
44
+ w, h = images.size
45
+
46
+ coef = w / h
47
+
48
+ w, h = img_size, img_size
49
+
50
+ if coef >= 1:
51
+ w = int(round(img_size / 8 * coef) * 8)
52
+ else:
53
+ h = int(round(img_size / 8 / coef) * 8)
54
+
55
+ images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
56
+
57
+ return images
58
+
59
+
60
+ EXAMPLE_DOC_STRING = """
61
+ Examples:
62
+ ```py
63
+ >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
64
+ >>> from diffusers.utils import pt_to_pil
65
+ >>> import torch
66
+ >>> from PIL import Image
67
+ >>> import requests
68
+ >>> from io import BytesIO
69
+
70
+ >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
71
+ >>> response = requests.get(url)
72
+ >>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
73
+ >>> original_image = original_image
74
+
75
+ >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
76
+ >>> response = requests.get(url)
77
+ >>> mask_image = Image.open(BytesIO(response.content))
78
+ >>> mask_image = mask_image
79
+
80
+ >>> pipe = IFInpaintingPipeline.from_pretrained(
81
+ ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
82
+ ... )
83
+ >>> pipe.enable_model_cpu_offload()
84
+
85
+ >>> prompt = "blue sunglasses"
86
+
87
+ >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
88
+ >>> image = pipe(
89
+ ... image=original_image,
90
+ ... mask_image=mask_image,
91
+ ... prompt_embeds=prompt_embeds,
92
+ ... negative_prompt_embeds=negative_embeds,
93
+ ... output_type="pt",
94
+ ... ).images
95
+
96
+ >>> # save intermediate image
97
+ >>> pil_image = pt_to_pil(image)
98
+ >>> pil_image[0].save("./if_stage_I.png")
99
+
100
+ >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
101
+ ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
102
+ ... )
103
+ >>> super_res_1_pipe.enable_model_cpu_offload()
104
+
105
+ >>> image = super_res_1_pipe(
106
+ ... image=image,
107
+ ... mask_image=mask_image,
108
+ ... original_image=original_image,
109
+ ... prompt_embeds=prompt_embeds,
110
+ ... negative_prompt_embeds=negative_embeds,
111
+ ... ).images
112
+ >>> image[0].save("./if_stage_II.png")
113
+ ```
114
+ """
115
+
116
+
117
+ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
118
+ tokenizer: T5Tokenizer
119
+ text_encoder: T5EncoderModel
120
+
121
+ unet: UNet2DConditionModel
122
+ scheduler: DDPMScheduler
123
+ image_noising_scheduler: DDPMScheduler
124
+
125
+ feature_extractor: Optional[CLIPImageProcessor]
126
+ safety_checker: Optional[IFSafetyChecker]
127
+
128
+ watermarker: Optional[IFWatermarker]
129
+
130
+ bad_punct_regex = re.compile(
131
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
132
+ ) # noqa
133
+
134
+ _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
135
+
136
+ def __init__(
137
+ self,
138
+ tokenizer: T5Tokenizer,
139
+ text_encoder: T5EncoderModel,
140
+ unet: UNet2DConditionModel,
141
+ scheduler: DDPMScheduler,
142
+ image_noising_scheduler: DDPMScheduler,
143
+ safety_checker: Optional[IFSafetyChecker],
144
+ feature_extractor: Optional[CLIPImageProcessor],
145
+ watermarker: Optional[IFWatermarker],
146
+ requires_safety_checker: bool = True,
147
+ ):
148
+ super().__init__()
149
+
150
+ if safety_checker is None and requires_safety_checker:
151
+ logger.warning(
152
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
153
+ " that you abide to the conditions of the IF license and do not expose unfiltered"
154
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
155
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
156
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
157
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
158
+ )
159
+
160
+ if safety_checker is not None and feature_extractor is None:
161
+ raise ValueError(
162
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
163
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
164
+ )
165
+
166
+ if unet.config.in_channels != 6:
167
+ logger.warn(
168
+ "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
169
+ )
170
+
171
+ self.register_modules(
172
+ tokenizer=tokenizer,
173
+ text_encoder=text_encoder,
174
+ unet=unet,
175
+ scheduler=scheduler,
176
+ image_noising_scheduler=image_noising_scheduler,
177
+ safety_checker=safety_checker,
178
+ feature_extractor=feature_extractor,
179
+ watermarker=watermarker,
180
+ )
181
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
182
+
183
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
184
+ def enable_sequential_cpu_offload(self, gpu_id=0):
185
+ r"""
186
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
187
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
188
+ when their specific submodule has its `forward` method called.
189
+ """
190
+ if is_accelerate_available():
191
+ from accelerate import cpu_offload
192
+ else:
193
+ raise ImportError("Please install accelerate via `pip install accelerate`")
194
+
195
+ device = torch.device(f"cuda:{gpu_id}")
196
+
197
+ models = [
198
+ self.text_encoder,
199
+ self.unet,
200
+ ]
201
+ for cpu_offloaded_model in models:
202
+ if cpu_offloaded_model is not None:
203
+ cpu_offload(cpu_offloaded_model, device)
204
+
205
+ if self.safety_checker is not None:
206
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
207
+
208
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
209
+ def enable_model_cpu_offload(self, gpu_id=0):
210
+ r"""
211
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
212
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
213
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
214
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
215
+ """
216
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
217
+ from accelerate import cpu_offload_with_hook
218
+ else:
219
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
220
+
221
+ device = torch.device(f"cuda:{gpu_id}")
222
+
223
+ if self.device.type != "cpu":
224
+ self.to("cpu", silence_dtype_warnings=True)
225
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
226
+
227
+ hook = None
228
+
229
+ if self.text_encoder is not None:
230
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
231
+
232
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
233
+ # previous model. This will cause both models to be present on the device at the same time.
234
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
235
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
236
+ # the GPU.
237
+ self.text_encoder_offload_hook = hook
238
+
239
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
240
+
241
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
242
+ self.unet_offload_hook = hook
243
+
244
+ if self.safety_checker is not None:
245
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
246
+
247
+ # We'll offload the last model manually.
248
+ self.final_offload_hook = hook
249
+
250
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
251
+ def remove_all_hooks(self):
252
+ if is_accelerate_available():
253
+ from accelerate.hooks import remove_hook_from_module
254
+ else:
255
+ raise ImportError("Please install accelerate via `pip install accelerate`")
256
+
257
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
258
+ if model is not None:
259
+ remove_hook_from_module(model, recurse=True)
260
+
261
+ self.unet_offload_hook = None
262
+ self.text_encoder_offload_hook = None
263
+ self.final_offload_hook = None
264
+
265
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
266
+ def _text_preprocessing(self, text, clean_caption=False):
267
+ if clean_caption and not is_bs4_available():
268
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
269
+ logger.warn("Setting `clean_caption` to False...")
270
+ clean_caption = False
271
+
272
+ if clean_caption and not is_ftfy_available():
273
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
274
+ logger.warn("Setting `clean_caption` to False...")
275
+ clean_caption = False
276
+
277
+ if not isinstance(text, (tuple, list)):
278
+ text = [text]
279
+
280
+ def process(text: str):
281
+ if clean_caption:
282
+ text = self._clean_caption(text)
283
+ text = self._clean_caption(text)
284
+ else:
285
+ text = text.lower().strip()
286
+ return text
287
+
288
+ return [process(t) for t in text]
289
+
290
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
291
+ def _clean_caption(self, caption):
292
+ caption = str(caption)
293
+ caption = ul.unquote_plus(caption)
294
+ caption = caption.strip().lower()
295
+ caption = re.sub("<person>", "person", caption)
296
+ # urls:
297
+ caption = re.sub(
298
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
299
+ "",
300
+ caption,
301
+ ) # regex for urls
302
+ caption = re.sub(
303
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
304
+ "",
305
+ caption,
306
+ ) # regex for urls
307
+ # html:
308
+ caption = BeautifulSoup(caption, features="html.parser").text
309
+
310
+ # @<nickname>
311
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
312
+
313
+ # 31C0—31EF CJK Strokes
314
+ # 31F0—31FF Katakana Phonetic Extensions
315
+ # 3200—32FF Enclosed CJK Letters and Months
316
+ # 3300—33FF CJK Compatibility
317
+ # 3400—4DBF CJK Unified Ideographs Extension A
318
+ # 4DC0—4DFF Yijing Hexagram Symbols
319
+ # 4E00—9FFF CJK Unified Ideographs
320
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
321
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
322
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
323
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
324
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
325
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
326
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
327
+ #######################################################
328
+
329
+ # все виды тире / all types of dash --> "-"
330
+ caption = re.sub(
331
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
332
+ "-",
333
+ caption,
334
+ )
335
+
336
+ # кавычки к одному стандарту
337
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
338
+ caption = re.sub(r"[‘’]", "'", caption)
339
+
340
+ # &quot;
341
+ caption = re.sub(r"&quot;?", "", caption)
342
+ # &amp
343
+ caption = re.sub(r"&amp", "", caption)
344
+
345
+ # ip adresses:
346
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
347
+
348
+ # article ids:
349
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
350
+
351
+ # \n
352
+ caption = re.sub(r"\\n", " ", caption)
353
+
354
+ # "#123"
355
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
356
+ # "#12345.."
357
+ caption = re.sub(r"#\d{5,}\b", "", caption)
358
+ # "123456.."
359
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
360
+ # filenames:
361
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
362
+
363
+ #
364
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
365
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
366
+
367
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
368
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
369
+
370
+ # this-is-my-cute-cat / this_is_my_cute_cat
371
+ regex2 = re.compile(r"(?:\-|\_)")
372
+ if len(re.findall(regex2, caption)) > 3:
373
+ caption = re.sub(regex2, " ", caption)
374
+
375
+ caption = ftfy.fix_text(caption)
376
+ caption = html.unescape(html.unescape(caption))
377
+
378
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
379
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
380
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
381
+
382
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
383
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
384
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
385
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
386
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
387
+
388
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
389
+
390
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
391
+
392
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
393
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
394
+ caption = re.sub(r"\s+", " ", caption)
395
+
396
+ caption.strip()
397
+
398
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
399
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
400
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
401
+ caption = re.sub(r"^\.\S+$", "", caption)
402
+
403
+ return caption.strip()
404
+
405
+ @property
406
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
407
+ def _execution_device(self):
408
+ r"""
409
+ Returns the device on which the pipeline's models will be executed. After calling
410
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
411
+ hooks.
412
+ """
413
+ if not hasattr(self.unet, "_hf_hook"):
414
+ return self.device
415
+ for module in self.unet.modules():
416
+ if (
417
+ hasattr(module, "_hf_hook")
418
+ and hasattr(module._hf_hook, "execution_device")
419
+ and module._hf_hook.execution_device is not None
420
+ ):
421
+ return torch.device(module._hf_hook.execution_device)
422
+ return self.device
423
+
424
+ @torch.no_grad()
425
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
426
+ def encode_prompt(
427
+ self,
428
+ prompt,
429
+ do_classifier_free_guidance=True,
430
+ num_images_per_prompt=1,
431
+ device=None,
432
+ negative_prompt=None,
433
+ prompt_embeds: Optional[torch.FloatTensor] = None,
434
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
435
+ clean_caption: bool = False,
436
+ ):
437
+ r"""
438
+ Encodes the prompt into text encoder hidden states.
439
+
440
+ Args:
441
+ prompt (`str` or `List[str]`, *optional*):
442
+ prompt to be encoded
443
+ device: (`torch.device`, *optional*):
444
+ torch device to place the resulting embeddings on
445
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
446
+ number of images that should be generated per prompt
447
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
448
+ whether to use classifier free guidance or not
449
+ negative_prompt (`str` or `List[str]`, *optional*):
450
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
451
+ `negative_prompt_embeds` instead.
452
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
453
+ prompt_embeds (`torch.FloatTensor`, *optional*):
454
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
455
+ provided, text embeddings will be generated from `prompt` input argument.
456
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
457
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
458
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
459
+ argument.
460
+ """
461
+ if prompt is not None and negative_prompt is not None:
462
+ if type(prompt) is not type(negative_prompt):
463
+ raise TypeError(
464
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
465
+ f" {type(prompt)}."
466
+ )
467
+
468
+ if device is None:
469
+ device = self._execution_device
470
+
471
+ if prompt is not None and isinstance(prompt, str):
472
+ batch_size = 1
473
+ elif prompt is not None and isinstance(prompt, list):
474
+ batch_size = len(prompt)
475
+ else:
476
+ batch_size = prompt_embeds.shape[0]
477
+
478
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
479
+ max_length = 77
480
+
481
+ if prompt_embeds is None:
482
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
483
+ text_inputs = self.tokenizer(
484
+ prompt,
485
+ padding="max_length",
486
+ max_length=max_length,
487
+ truncation=True,
488
+ add_special_tokens=True,
489
+ return_tensors="pt",
490
+ )
491
+ text_input_ids = text_inputs.input_ids
492
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
493
+
494
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
495
+ text_input_ids, untruncated_ids
496
+ ):
497
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
498
+ logger.warning(
499
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
500
+ f" {max_length} tokens: {removed_text}"
501
+ )
502
+
503
+ attention_mask = text_inputs.attention_mask.to(device)
504
+
505
+ prompt_embeds = self.text_encoder(
506
+ text_input_ids.to(device),
507
+ attention_mask=attention_mask,
508
+ )
509
+ prompt_embeds = prompt_embeds[0]
510
+
511
+ if self.text_encoder is not None:
512
+ dtype = self.text_encoder.dtype
513
+ elif self.unet is not None:
514
+ dtype = self.unet.dtype
515
+ else:
516
+ dtype = None
517
+
518
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
519
+
520
+ bs_embed, seq_len, _ = prompt_embeds.shape
521
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
522
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
523
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
524
+
525
+ # get unconditional embeddings for classifier free guidance
526
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
527
+ uncond_tokens: List[str]
528
+ if negative_prompt is None:
529
+ uncond_tokens = [""] * batch_size
530
+ elif isinstance(negative_prompt, str):
531
+ uncond_tokens = [negative_prompt]
532
+ elif batch_size != len(negative_prompt):
533
+ raise ValueError(
534
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
535
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
536
+ " the batch size of `prompt`."
537
+ )
538
+ else:
539
+ uncond_tokens = negative_prompt
540
+
541
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
542
+ max_length = prompt_embeds.shape[1]
543
+ uncond_input = self.tokenizer(
544
+ uncond_tokens,
545
+ padding="max_length",
546
+ max_length=max_length,
547
+ truncation=True,
548
+ return_attention_mask=True,
549
+ add_special_tokens=True,
550
+ return_tensors="pt",
551
+ )
552
+ attention_mask = uncond_input.attention_mask.to(device)
553
+
554
+ negative_prompt_embeds = self.text_encoder(
555
+ uncond_input.input_ids.to(device),
556
+ attention_mask=attention_mask,
557
+ )
558
+ negative_prompt_embeds = negative_prompt_embeds[0]
559
+
560
+ if do_classifier_free_guidance:
561
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
562
+ seq_len = negative_prompt_embeds.shape[1]
563
+
564
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
565
+
566
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
567
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
568
+
569
+ # For classifier free guidance, we need to do two forward passes.
570
+ # Here we concatenate the unconditional and text embeddings into a single batch
571
+ # to avoid doing two forward passes
572
+ else:
573
+ negative_prompt_embeds = None
574
+
575
+ return prompt_embeds, negative_prompt_embeds
576
+
577
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
578
+ def run_safety_checker(self, image, device, dtype):
579
+ if self.safety_checker is not None:
580
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
581
+ image, nsfw_detected, watermark_detected = self.safety_checker(
582
+ images=image,
583
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
584
+ )
585
+ else:
586
+ nsfw_detected = None
587
+ watermark_detected = None
588
+
589
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
590
+ self.unet_offload_hook.offload()
591
+
592
+ return image, nsfw_detected, watermark_detected
593
+
594
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
595
+ def prepare_extra_step_kwargs(self, generator, eta):
596
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
597
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
598
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
599
+ # and should be between [0, 1]
600
+
601
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
602
+ extra_step_kwargs = {}
603
+ if accepts_eta:
604
+ extra_step_kwargs["eta"] = eta
605
+
606
+ # check if the scheduler accepts generator
607
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
608
+ if accepts_generator:
609
+ extra_step_kwargs["generator"] = generator
610
+ return extra_step_kwargs
611
+
612
+ def check_inputs(
613
+ self,
614
+ prompt,
615
+ image,
616
+ original_image,
617
+ mask_image,
618
+ batch_size,
619
+ callback_steps,
620
+ negative_prompt=None,
621
+ prompt_embeds=None,
622
+ negative_prompt_embeds=None,
623
+ ):
624
+ if (callback_steps is None) or (
625
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
626
+ ):
627
+ raise ValueError(
628
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
629
+ f" {type(callback_steps)}."
630
+ )
631
+
632
+ if prompt is not None and prompt_embeds is not None:
633
+ raise ValueError(
634
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
635
+ " only forward one of the two."
636
+ )
637
+ elif prompt is None and prompt_embeds is None:
638
+ raise ValueError(
639
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
640
+ )
641
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
642
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
643
+
644
+ if negative_prompt is not None and negative_prompt_embeds is not None:
645
+ raise ValueError(
646
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
647
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
648
+ )
649
+
650
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
651
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
652
+ raise ValueError(
653
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
654
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
655
+ f" {negative_prompt_embeds.shape}."
656
+ )
657
+
658
+ # image
659
+
660
+ if isinstance(image, list):
661
+ check_image_type = image[0]
662
+ else:
663
+ check_image_type = image
664
+
665
+ if (
666
+ not isinstance(check_image_type, torch.Tensor)
667
+ and not isinstance(check_image_type, PIL.Image.Image)
668
+ and not isinstance(check_image_type, np.ndarray)
669
+ ):
670
+ raise ValueError(
671
+ "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
672
+ f" {type(check_image_type)}"
673
+ )
674
+
675
+ if isinstance(image, list):
676
+ image_batch_size = len(image)
677
+ elif isinstance(image, torch.Tensor):
678
+ image_batch_size = image.shape[0]
679
+ elif isinstance(image, PIL.Image.Image):
680
+ image_batch_size = 1
681
+ elif isinstance(image, np.ndarray):
682
+ image_batch_size = image.shape[0]
683
+ else:
684
+ assert False
685
+
686
+ if batch_size != image_batch_size:
687
+ raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
688
+
689
+ # original_image
690
+
691
+ if isinstance(original_image, list):
692
+ check_image_type = original_image[0]
693
+ else:
694
+ check_image_type = original_image
695
+
696
+ if (
697
+ not isinstance(check_image_type, torch.Tensor)
698
+ and not isinstance(check_image_type, PIL.Image.Image)
699
+ and not isinstance(check_image_type, np.ndarray)
700
+ ):
701
+ raise ValueError(
702
+ "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
703
+ f" {type(check_image_type)}"
704
+ )
705
+
706
+ if isinstance(original_image, list):
707
+ image_batch_size = len(original_image)
708
+ elif isinstance(original_image, torch.Tensor):
709
+ image_batch_size = original_image.shape[0]
710
+ elif isinstance(original_image, PIL.Image.Image):
711
+ image_batch_size = 1
712
+ elif isinstance(original_image, np.ndarray):
713
+ image_batch_size = original_image.shape[0]
714
+ else:
715
+ assert False
716
+
717
+ if batch_size != image_batch_size:
718
+ raise ValueError(
719
+ f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
720
+ )
721
+
722
+ # mask_image
723
+
724
+ if isinstance(mask_image, list):
725
+ check_image_type = mask_image[0]
726
+ else:
727
+ check_image_type = mask_image
728
+
729
+ if (
730
+ not isinstance(check_image_type, torch.Tensor)
731
+ and not isinstance(check_image_type, PIL.Image.Image)
732
+ and not isinstance(check_image_type, np.ndarray)
733
+ ):
734
+ raise ValueError(
735
+ "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
736
+ f" {type(check_image_type)}"
737
+ )
738
+
739
+ if isinstance(mask_image, list):
740
+ image_batch_size = len(mask_image)
741
+ elif isinstance(mask_image, torch.Tensor):
742
+ image_batch_size = mask_image.shape[0]
743
+ elif isinstance(mask_image, PIL.Image.Image):
744
+ image_batch_size = 1
745
+ elif isinstance(mask_image, np.ndarray):
746
+ image_batch_size = mask_image.shape[0]
747
+ else:
748
+ assert False
749
+
750
+ if image_batch_size != 1 and batch_size != image_batch_size:
751
+ raise ValueError(
752
+ f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}"
753
+ )
754
+
755
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image
756
+ def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor:
757
+ if not isinstance(image, list):
758
+ image = [image]
759
+
760
+ def numpy_to_pt(images):
761
+ if images.ndim == 3:
762
+ images = images[..., None]
763
+
764
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
765
+ return images
766
+
767
+ if isinstance(image[0], PIL.Image.Image):
768
+ new_image = []
769
+
770
+ for image_ in image:
771
+ image_ = image_.convert("RGB")
772
+ image_ = resize(image_, self.unet.sample_size)
773
+ image_ = np.array(image_)
774
+ image_ = image_.astype(np.float32)
775
+ image_ = image_ / 127.5 - 1
776
+ new_image.append(image_)
777
+
778
+ image = new_image
779
+
780
+ image = np.stack(image, axis=0) # to np
781
+ image = numpy_to_pt(image) # to pt
782
+
783
+ elif isinstance(image[0], np.ndarray):
784
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
785
+ image = numpy_to_pt(image)
786
+
787
+ elif isinstance(image[0], torch.Tensor):
788
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
789
+
790
+ return image
791
+
792
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image
793
+ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
794
+ if not isinstance(image, torch.Tensor) and not isinstance(image, list):
795
+ image = [image]
796
+
797
+ if isinstance(image[0], PIL.Image.Image):
798
+ image = [np.array(i).astype(np.float32) / 255.0 for i in image]
799
+
800
+ image = np.stack(image, axis=0) # to np
801
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
802
+ elif isinstance(image[0], np.ndarray):
803
+ image = np.stack(image, axis=0) # to np
804
+ if image.ndim == 5:
805
+ image = image[0]
806
+
807
+ image = torch.from_numpy(image.transpose(0, 3, 1, 2))
808
+ elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
809
+ dims = image[0].ndim
810
+
811
+ if dims == 3:
812
+ image = torch.stack(image, dim=0)
813
+ elif dims == 4:
814
+ image = torch.concat(image, dim=0)
815
+ else:
816
+ raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
817
+
818
+ image = image.to(device=device, dtype=self.unet.dtype)
819
+
820
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
821
+
822
+ return image
823
+
824
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image
825
+ def preprocess_mask_image(self, mask_image) -> torch.Tensor:
826
+ if not isinstance(mask_image, list):
827
+ mask_image = [mask_image]
828
+
829
+ if isinstance(mask_image[0], torch.Tensor):
830
+ mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0)
831
+
832
+ if mask_image.ndim == 2:
833
+ # Batch and add channel dim for single mask
834
+ mask_image = mask_image.unsqueeze(0).unsqueeze(0)
835
+ elif mask_image.ndim == 3 and mask_image.shape[0] == 1:
836
+ # Single mask, the 0'th dimension is considered to be
837
+ # the existing batch size of 1
838
+ mask_image = mask_image.unsqueeze(0)
839
+ elif mask_image.ndim == 3 and mask_image.shape[0] != 1:
840
+ # Batch of mask, the 0'th dimension is considered to be
841
+ # the batching dimension
842
+ mask_image = mask_image.unsqueeze(1)
843
+
844
+ mask_image[mask_image < 0.5] = 0
845
+ mask_image[mask_image >= 0.5] = 1
846
+
847
+ elif isinstance(mask_image[0], PIL.Image.Image):
848
+ new_mask_image = []
849
+
850
+ for mask_image_ in mask_image:
851
+ mask_image_ = mask_image_.convert("L")
852
+ mask_image_ = resize(mask_image_, self.unet.sample_size)
853
+ mask_image_ = np.array(mask_image_)
854
+ mask_image_ = mask_image_[None, None, :]
855
+ new_mask_image.append(mask_image_)
856
+
857
+ mask_image = new_mask_image
858
+
859
+ mask_image = np.concatenate(mask_image, axis=0)
860
+ mask_image = mask_image.astype(np.float32) / 255.0
861
+ mask_image[mask_image < 0.5] = 0
862
+ mask_image[mask_image >= 0.5] = 1
863
+ mask_image = torch.from_numpy(mask_image)
864
+
865
+ elif isinstance(mask_image[0], np.ndarray):
866
+ mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)
867
+
868
+ mask_image[mask_image < 0.5] = 0
869
+ mask_image[mask_image >= 0.5] = 1
870
+ mask_image = torch.from_numpy(mask_image)
871
+
872
+ return mask_image
873
+
874
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
875
+ def get_timesteps(self, num_inference_steps, strength):
876
+ # get the original timestep using init_timestep
877
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
878
+
879
+ t_start = max(num_inference_steps - init_timestep, 0)
880
+ timesteps = self.scheduler.timesteps[t_start:]
881
+
882
+ return timesteps, num_inference_steps - t_start
883
+
884
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images
885
+ def prepare_intermediate_images(
886
+ self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None
887
+ ):
888
+ image_batch_size, channels, height, width = image.shape
889
+
890
+ batch_size = batch_size * num_images_per_prompt
891
+
892
+ shape = (batch_size, channels, height, width)
893
+
894
+ if isinstance(generator, list) and len(generator) != batch_size:
895
+ raise ValueError(
896
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
897
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
898
+ )
899
+
900
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
901
+
902
+ image = image.repeat_interleave(num_images_per_prompt, dim=0)
903
+ noised_image = self.scheduler.add_noise(image, noise, timestep)
904
+
905
+ image = (1 - mask_image) * image + mask_image * noised_image
906
+
907
+ return image
908
+
909
+ @torch.no_grad()
910
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
911
+ def __call__(
912
+ self,
913
+ image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
914
+ original_image: Union[
915
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
916
+ ] = None,
917
+ mask_image: Union[
918
+ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
919
+ ] = None,
920
+ strength: float = 0.8,
921
+ prompt: Union[str, List[str]] = None,
922
+ num_inference_steps: int = 100,
923
+ timesteps: List[int] = None,
924
+ guidance_scale: float = 4.0,
925
+ negative_prompt: Optional[Union[str, List[str]]] = None,
926
+ num_images_per_prompt: Optional[int] = 1,
927
+ eta: float = 0.0,
928
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
929
+ prompt_embeds: Optional[torch.FloatTensor] = None,
930
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
931
+ output_type: Optional[str] = "pil",
932
+ return_dict: bool = True,
933
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
934
+ callback_steps: int = 1,
935
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
936
+ noise_level: int = 0,
937
+ clean_caption: bool = True,
938
+ ):
939
+ """
940
+ Function invoked when calling the pipeline for generation.
941
+
942
+ Args:
943
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
944
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
945
+ process.
946
+ original_image (`torch.FloatTensor` or `PIL.Image.Image`):
947
+ The original image that `image` was varied from.
948
+ mask_image (`PIL.Image.Image`):
949
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
950
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
951
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
952
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
953
+ strength (`float`, *optional*, defaults to 0.8):
954
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
955
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
956
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
957
+ be maximum and the denoising process will run for the full number of iterations specified in
958
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
959
+ prompt (`str` or `List[str]`, *optional*):
960
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
961
+ instead.
962
+ num_inference_steps (`int`, *optional*, defaults to 100):
963
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
964
+ expense of slower inference.
965
+ timesteps (`List[int]`, *optional*):
966
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
967
+ timesteps are used. Must be in descending order.
968
+ guidance_scale (`float`, *optional*, defaults to 4.0):
969
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
970
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
971
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
972
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
973
+ usually at the expense of lower image quality.
974
+ negative_prompt (`str` or `List[str]`, *optional*):
975
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
976
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
977
+ less than `1`).
978
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
979
+ The number of images to generate per prompt.
980
+ eta (`float`, *optional*, defaults to 0.0):
981
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
982
+ [`schedulers.DDIMScheduler`], will be ignored for others.
983
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
984
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
985
+ to make generation deterministic.
986
+ prompt_embeds (`torch.FloatTensor`, *optional*):
987
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
988
+ provided, text embeddings will be generated from `prompt` input argument.
989
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
990
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
991
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
992
+ argument.
993
+ output_type (`str`, *optional*, defaults to `"pil"`):
994
+ The output format of the generated image. Choose between
995
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
996
+ return_dict (`bool`, *optional*, defaults to `True`):
997
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
998
+ callback (`Callable`, *optional*):
999
+ A function that will be called every `callback_steps` steps during inference. The function will be
1000
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1001
+ callback_steps (`int`, *optional*, defaults to 1):
1002
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1003
+ called at every step.
1004
+ cross_attention_kwargs (`dict`, *optional*):
1005
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1006
+ `self.processor` in
1007
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1008
+ noise_level (`int`, *optional*, defaults to 0):
1009
+ The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
1010
+ clean_caption (`bool`, *optional*, defaults to `True`):
1011
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
1012
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
1013
+ prompt.
1014
+
1015
+ Examples:
1016
+
1017
+ Returns:
1018
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
1019
+ [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
1020
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
1021
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
1022
+ or watermarked content, according to the `safety_checker`.
1023
+ """
1024
+ # 1. Check inputs. Raise error if not correct
1025
+ if prompt is not None and isinstance(prompt, str):
1026
+ batch_size = 1
1027
+ elif prompt is not None and isinstance(prompt, list):
1028
+ batch_size = len(prompt)
1029
+ else:
1030
+ batch_size = prompt_embeds.shape[0]
1031
+
1032
+ self.check_inputs(
1033
+ prompt,
1034
+ image,
1035
+ original_image,
1036
+ mask_image,
1037
+ batch_size,
1038
+ callback_steps,
1039
+ negative_prompt,
1040
+ prompt_embeds,
1041
+ negative_prompt_embeds,
1042
+ )
1043
+
1044
+ # 2. Define call parameters
1045
+
1046
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1047
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1048
+ # corresponds to doing no classifier free guidance.
1049
+ do_classifier_free_guidance = guidance_scale > 1.0
1050
+
1051
+ device = self._execution_device
1052
+
1053
+ # 3. Encode input prompt
1054
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1055
+ prompt,
1056
+ do_classifier_free_guidance,
1057
+ num_images_per_prompt=num_images_per_prompt,
1058
+ device=device,
1059
+ negative_prompt=negative_prompt,
1060
+ prompt_embeds=prompt_embeds,
1061
+ negative_prompt_embeds=negative_prompt_embeds,
1062
+ clean_caption=clean_caption,
1063
+ )
1064
+
1065
+ if do_classifier_free_guidance:
1066
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1067
+
1068
+ dtype = prompt_embeds.dtype
1069
+
1070
+ # 4. Prepare timesteps
1071
+ if timesteps is not None:
1072
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
1073
+ timesteps = self.scheduler.timesteps
1074
+ num_inference_steps = len(timesteps)
1075
+ else:
1076
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1077
+ timesteps = self.scheduler.timesteps
1078
+
1079
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
1080
+
1081
+ # 5. prepare original image
1082
+ original_image = self.preprocess_original_image(original_image)
1083
+ original_image = original_image.to(device=device, dtype=dtype)
1084
+
1085
+ # 6. prepare mask image
1086
+ mask_image = self.preprocess_mask_image(mask_image)
1087
+ mask_image = mask_image.to(device=device, dtype=dtype)
1088
+
1089
+ if mask_image.shape[0] == 1:
1090
+ mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
1091
+ else:
1092
+ mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
1093
+
1094
+ # 6. Prepare intermediate images
1095
+ noise_timestep = timesteps[0:1]
1096
+ noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)
1097
+
1098
+ intermediate_images = self.prepare_intermediate_images(
1099
+ original_image,
1100
+ noise_timestep,
1101
+ batch_size,
1102
+ num_images_per_prompt,
1103
+ dtype,
1104
+ device,
1105
+ mask_image,
1106
+ generator,
1107
+ )
1108
+
1109
+ # 7. Prepare upscaled image and noise level
1110
+ _, _, height, width = original_image.shape
1111
+
1112
+ image = self.preprocess_image(image, num_images_per_prompt, device)
1113
+
1114
+ upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
1115
+
1116
+ noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
1117
+ noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
1118
+ upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
1119
+
1120
+ if do_classifier_free_guidance:
1121
+ noise_level = torch.cat([noise_level] * 2)
1122
+
1123
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1124
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1125
+
1126
+ # HACK: see comment in `enable_model_cpu_offload`
1127
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
1128
+ self.text_encoder_offload_hook.offload()
1129
+
1130
+ # 9. Denoising loop
1131
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1132
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1133
+ for i, t in enumerate(timesteps):
1134
+ model_input = torch.cat([intermediate_images, upscaled], dim=1)
1135
+
1136
+ model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
1137
+ model_input = self.scheduler.scale_model_input(model_input, t)
1138
+
1139
+ # predict the noise residual
1140
+ noise_pred = self.unet(
1141
+ model_input,
1142
+ t,
1143
+ encoder_hidden_states=prompt_embeds,
1144
+ class_labels=noise_level,
1145
+ cross_attention_kwargs=cross_attention_kwargs,
1146
+ ).sample
1147
+
1148
+ # perform guidance
1149
+ if do_classifier_free_guidance:
1150
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1151
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
1152
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
1153
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1154
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
1155
+
1156
+ # compute the previous noisy sample x_t -> x_t-1
1157
+ prev_intermediate_images = intermediate_images
1158
+
1159
+ intermediate_images = self.scheduler.step(
1160
+ noise_pred, t, intermediate_images, **extra_step_kwargs
1161
+ ).prev_sample
1162
+
1163
+ intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images
1164
+
1165
+ # call the callback, if provided
1166
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1167
+ progress_bar.update()
1168
+ if callback is not None and i % callback_steps == 0:
1169
+ callback(i, t, intermediate_images)
1170
+
1171
+ image = intermediate_images
1172
+
1173
+ if output_type == "pil":
1174
+ # 10. Post-processing
1175
+ image = (image / 2 + 0.5).clamp(0, 1)
1176
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1177
+
1178
+ # 11. Run safety checker
1179
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
1180
+
1181
+ # 12. Convert to PIL
1182
+ image = self.numpy_to_pil(image)
1183
+
1184
+ # 13. Apply watermark
1185
+ if self.watermarker is not None:
1186
+ self.watermarker.apply_watermark(image, self.unet.config.sample_size)
1187
+ elif output_type == "pt":
1188
+ nsfw_detected = None
1189
+ watermark_detected = None
1190
+
1191
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
1192
+ self.unet_offload_hook.offload()
1193
+ else:
1194
+ # 10. Post-processing
1195
+ image = (image / 2 + 0.5).clamp(0, 1)
1196
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1197
+
1198
+ # 11. Run safety checker
1199
+ image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
1200
+
1201
+ # Offload last model to CPU
1202
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1203
+ self.final_offload_hook.offload()
1204
+
1205
+ if not return_dict:
1206
+ return (image, nsfw_detected, watermark_detected)
1207
+
1208
+ return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
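One detail in the denoising loop above that is easy to misread: the stage-II UNet predicts both a noise residual and a per-pixel variance, so after the classifier-free-guidance batch is split in two, each half is split again along the channel dimension, guidance is applied to the noise part only, and the predicted variance from the text-conditioned half is concatenated back before the scheduler step. A self-contained sketch of that tensor manipulation follows; the shapes are illustrative only, not taken from the checkpoint.

```py
>>> # Editor's sketch of the guidance step in the loop above, using dummy tensors.
>>> import torch

>>> guidance_scale = 4.0
>>> channels = 6  # model_input channels: intermediate images (3) + upscaled conditioning (3)
>>> noise_pred = torch.randn(2, channels, 8, 8)  # doubled batch: [unconditional, text-conditioned]

>>> noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
>>> # each half carries 3 noise channels followed by 3 predicted-variance channels
>>> noise_pred_uncond, _ = noise_pred_uncond.split(channels // 2, dim=1)
>>> noise_pred_text, predicted_variance = noise_pred_text.split(channels // 2, dim=1)

>>> noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
>>> noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)  # re-attach variance for scheduler.step
>>> noise_pred.shape
torch.Size([1, 6, 8, 8])
```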