hcpdiff 0.9.0__py3-none-any.whl → 2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. hcpdiff/__init__.py +4 -4
  2. hcpdiff/ckpt_manager/__init__.py +4 -5
  3. hcpdiff/ckpt_manager/ckpt.py +24 -0
  4. hcpdiff/ckpt_manager/format/__init__.py +4 -0
  5. hcpdiff/ckpt_manager/format/diffusers.py +59 -0
  6. hcpdiff/ckpt_manager/format/emb.py +21 -0
  7. hcpdiff/ckpt_manager/format/lora_webui.py +244 -0
  8. hcpdiff/ckpt_manager/format/sd_single.py +41 -0
  9. hcpdiff/ckpt_manager/loader.py +64 -0
  10. hcpdiff/data/__init__.py +4 -28
  11. hcpdiff/data/cache/__init__.py +1 -0
  12. hcpdiff/data/cache/vae.py +102 -0
  13. hcpdiff/data/dataset.py +20 -0
  14. hcpdiff/data/handler/__init__.py +3 -0
  15. hcpdiff/data/handler/controlnet.py +18 -0
  16. hcpdiff/data/handler/diffusion.py +80 -0
  17. hcpdiff/data/handler/text.py +111 -0
  18. hcpdiff/data/source/__init__.py +1 -2
  19. hcpdiff/data/source/folder_class.py +12 -29
  20. hcpdiff/data/source/text2img.py +36 -74
  21. hcpdiff/data/source/text2img_cond.py +9 -15
  22. hcpdiff/diffusion/__init__.py +0 -0
  23. hcpdiff/diffusion/noise/__init__.py +2 -0
  24. hcpdiff/diffusion/noise/pyramid_noise.py +42 -0
  25. hcpdiff/diffusion/noise/zero_terminal.py +39 -0
  26. hcpdiff/diffusion/sampler/__init__.py +5 -0
  27. hcpdiff/diffusion/sampler/base.py +72 -0
  28. hcpdiff/diffusion/sampler/ddpm.py +20 -0
  29. hcpdiff/diffusion/sampler/diffusers.py +66 -0
  30. hcpdiff/diffusion/sampler/edm.py +22 -0
  31. hcpdiff/diffusion/sampler/sigma_scheduler/__init__.py +3 -0
  32. hcpdiff/diffusion/sampler/sigma_scheduler/base.py +14 -0
  33. hcpdiff/diffusion/sampler/sigma_scheduler/ddpm.py +197 -0
  34. hcpdiff/diffusion/sampler/sigma_scheduler/edm.py +48 -0
  35. hcpdiff/easy/__init__.py +2 -0
  36. hcpdiff/easy/cfg/__init__.py +3 -0
  37. hcpdiff/easy/cfg/sd15_train.py +201 -0
  38. hcpdiff/easy/cfg/sdxl_train.py +140 -0
  39. hcpdiff/easy/cfg/t2i.py +177 -0
  40. hcpdiff/easy/model/__init__.py +2 -0
  41. hcpdiff/easy/model/cnet.py +31 -0
  42. hcpdiff/easy/model/loader.py +79 -0
  43. hcpdiff/easy/sampler.py +46 -0
  44. hcpdiff/evaluate/__init__.py +1 -0
  45. hcpdiff/evaluate/previewer.py +60 -0
  46. hcpdiff/loss/__init__.py +4 -1
  47. hcpdiff/loss/base.py +41 -0
  48. hcpdiff/loss/gw.py +35 -0
  49. hcpdiff/loss/ssim.py +37 -0
  50. hcpdiff/loss/vlb.py +79 -0
  51. hcpdiff/loss/weighting.py +66 -0
  52. hcpdiff/models/__init__.py +2 -2
  53. hcpdiff/models/cfg_context.py +17 -14
  54. hcpdiff/models/compose/compose_hook.py +44 -23
  55. hcpdiff/models/compose/compose_tokenizer.py +21 -8
  56. hcpdiff/models/compose/sdxl_composer.py +4 -4
  57. hcpdiff/models/container.py +1 -1
  58. hcpdiff/models/controlnet.py +16 -16
  59. hcpdiff/models/lora_base_patch.py +14 -25
  60. hcpdiff/models/lora_layers.py +3 -9
  61. hcpdiff/models/lora_layers_patch.py +14 -24
  62. hcpdiff/models/text_emb_ex.py +84 -6
  63. hcpdiff/models/textencoder_ex.py +54 -18
  64. hcpdiff/models/wrapper/__init__.py +3 -0
  65. hcpdiff/models/wrapper/pixart.py +19 -0
  66. hcpdiff/models/wrapper/sd.py +218 -0
  67. hcpdiff/models/wrapper/utils.py +20 -0
  68. hcpdiff/parser/__init__.py +1 -0
  69. hcpdiff/parser/embpt.py +32 -0
  70. hcpdiff/tools/convert_caption_txt2json.py +1 -1
  71. hcpdiff/tools/dataset_generator.py +94 -0
  72. hcpdiff/tools/download_hf_model.py +24 -0
  73. hcpdiff/tools/embedding_convert.py +6 -2
  74. hcpdiff/tools/init_proj.py +3 -21
  75. hcpdiff/tools/lora_convert.py +19 -15
  76. hcpdiff/tools/save_model.py +12 -0
  77. hcpdiff/tools/sd2diffusers.py +1 -1
  78. hcpdiff/train_colo.py +1 -1
  79. hcpdiff/train_deepspeed.py +1 -1
  80. hcpdiff/trainer_ac.py +79 -0
  81. hcpdiff/trainer_ac_single.py +31 -0
  82. hcpdiff/utils/__init__.py +0 -2
  83. hcpdiff/utils/inpaint_pipe.py +790 -0
  84. hcpdiff/utils/net_utils.py +29 -6
  85. hcpdiff/utils/pipe_hook.py +46 -33
  86. hcpdiff/utils/utils.py +21 -4
  87. hcpdiff/workflow/__init__.py +15 -10
  88. hcpdiff/workflow/daam/__init__.py +1 -0
  89. hcpdiff/workflow/daam/act.py +66 -0
  90. hcpdiff/workflow/daam/hook.py +109 -0
  91. hcpdiff/workflow/diffusion.py +128 -136
  92. hcpdiff/workflow/fast.py +31 -0
  93. hcpdiff/workflow/flow.py +67 -0
  94. hcpdiff/workflow/io.py +36 -68
  95. hcpdiff/workflow/model.py +46 -43
  96. hcpdiff/workflow/text.py +84 -52
  97. hcpdiff/workflow/utils.py +32 -12
  98. hcpdiff/workflow/vae.py +37 -38
  99. hcpdiff-2.1.dist-info/METADATA +285 -0
  100. hcpdiff-2.1.dist-info/RECORD +114 -0
  101. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/WHEEL +1 -1
  102. hcpdiff-2.1.dist-info/entry_points.txt +5 -0
  103. hcpdiff/ckpt_manager/base.py +0 -16
  104. hcpdiff/ckpt_manager/ckpt_diffusers.py +0 -45
  105. hcpdiff/ckpt_manager/ckpt_pkl.py +0 -138
  106. hcpdiff/ckpt_manager/ckpt_safetensor.py +0 -60
  107. hcpdiff/ckpt_manager/ckpt_webui.py +0 -54
  108. hcpdiff/data/bucket.py +0 -358
  109. hcpdiff/data/caption_loader.py +0 -80
  110. hcpdiff/data/cond_dataset.py +0 -40
  111. hcpdiff/data/crop_info_dataset.py +0 -40
  112. hcpdiff/data/data_processor.py +0 -33
  113. hcpdiff/data/pair_dataset.py +0 -146
  114. hcpdiff/data/sampler.py +0 -54
  115. hcpdiff/data/source/base.py +0 -30
  116. hcpdiff/data/utils.py +0 -80
  117. hcpdiff/infer_workflow.py +0 -57
  118. hcpdiff/loggers/__init__.py +0 -13
  119. hcpdiff/loggers/base_logger.py +0 -76
  120. hcpdiff/loggers/cli_logger.py +0 -40
  121. hcpdiff/loggers/preview/__init__.py +0 -1
  122. hcpdiff/loggers/preview/image_previewer.py +0 -149
  123. hcpdiff/loggers/tensorboard_logger.py +0 -30
  124. hcpdiff/loggers/wandb_logger.py +0 -31
  125. hcpdiff/loggers/webui_logger.py +0 -9
  126. hcpdiff/loss/min_snr_loss.py +0 -52
  127. hcpdiff/models/layers.py +0 -81
  128. hcpdiff/models/plugin.py +0 -348
  129. hcpdiff/models/wrapper.py +0 -75
  130. hcpdiff/noise/__init__.py +0 -3
  131. hcpdiff/noise/noise_base.py +0 -16
  132. hcpdiff/noise/pyramid_noise.py +0 -50
  133. hcpdiff/noise/zero_terminal.py +0 -44
  134. hcpdiff/train_ac.py +0 -565
  135. hcpdiff/train_ac_single.py +0 -39
  136. hcpdiff/utils/caption_tools.py +0 -105
  137. hcpdiff/utils/cfg_net_tools.py +0 -321
  138. hcpdiff/utils/cfg_resolvers.py +0 -16
  139. hcpdiff/utils/ema.py +0 -52
  140. hcpdiff/utils/img_size_tool.py +0 -248
  141. hcpdiff/vis/__init__.py +0 -3
  142. hcpdiff/vis/base_interface.py +0 -12
  143. hcpdiff/vis/disk_interface.py +0 -48
  144. hcpdiff/vis/webui_interface.py +0 -17
  145. hcpdiff/visualizer.py +0 -258
  146. hcpdiff/visualizer_reloadable.py +0 -237
  147. hcpdiff/workflow/base.py +0 -59
  148. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime.yaml +0 -21
  149. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/anime/text2img_anime_lora.yaml +0 -58
  150. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/change_vae.yaml +0 -6
  151. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/euler_a.yaml +0 -8
  152. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img.yaml +0 -10
  153. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/img2img_controlnet.yaml +0 -19
  154. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/inpaint.yaml +0 -11
  155. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_lora.yaml +0 -26
  156. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/load_unet_part.yaml +0 -18
  157. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/offload_2GB.yaml +0 -6
  158. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/save_model.yaml +0 -44
  159. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img.yaml +0 -53
  160. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_DA++.yaml +0 -34
  161. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/infer/text2img_sdxl.yaml +0 -9
  162. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/plugins/plugin_controlnet.yaml +0 -17
  163. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/te_struct.txt +0 -193
  164. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/base_dataset.yaml +0 -29
  165. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/dataset/regularization_dataset.yaml +0 -31
  166. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/CustomDiffusion.yaml +0 -74
  167. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist++.yaml +0 -135
  168. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamArtist.yaml +0 -45
  169. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/DreamBooth.yaml +0 -62
  170. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/FT_sdxl.yaml +0 -33
  171. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/Lion_optimizer.yaml +0 -17
  172. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/TextualInversion.yaml +0 -41
  173. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/add_logger_tensorboard_wandb.yaml +0 -15
  174. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/controlnet.yaml +0 -53
  175. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/ema.yaml +0 -10
  176. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/fine-tuning.yaml +0 -53
  177. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/locon.yaml +0 -24
  178. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_anime_character.yaml +0 -77
  179. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_conventional.yaml +0 -56
  180. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/lora_sdxl.yaml +0 -41
  181. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/min_snr.yaml +0 -7
  182. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples/preview_in_training.yaml +0 -6
  183. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/DreamBooth.yaml +0 -70
  184. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/TextualInversion.yaml +0 -45
  185. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/fine-tuning.yaml +0 -45
  186. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/examples_noob/lora.yaml +0 -63
  187. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/train_base.yaml +0 -81
  188. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/train/tuning_base.yaml +0 -42
  189. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/unet_struct.txt +0 -932
  190. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_latent.yaml +0 -86
  191. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/highres_fix_pixel.yaml +0 -99
  192. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img.yaml +0 -57
  193. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/workflow/text2img_lora.yaml +0 -70
  194. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero2.json +0 -32
  195. hcpdiff-0.9.0.data/data/hcpdiff/cfgs/zero3.json +0 -39
  196. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/caption.txt +0 -1
  197. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name.txt +0 -1
  198. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_2pt_caption.txt +0 -1
  199. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/name_caption.txt +0 -1
  200. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object.txt +0 -27
  201. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/object_caption.txt +0 -27
  202. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style.txt +0 -19
  203. hcpdiff-0.9.0.data/data/hcpdiff/prompt_tuning_template/style_caption.txt +0 -19
  204. hcpdiff-0.9.0.dist-info/METADATA +0 -199
  205. hcpdiff-0.9.0.dist-info/RECORD +0 -155
  206. hcpdiff-0.9.0.dist-info/entry_points.txt +0 -2
  207. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info/licenses}/LICENSE +0 -0
  208. {hcpdiff-0.9.0.dist-info → hcpdiff-2.1.dist-info}/top_level.txt +0 -0
hcpdiff/utils/inpaint_pipe.py (new file)
@@ -0,0 +1,790 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from packaging import version
22
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
23
+
24
+ from diffusers import StableDiffusionInpaintPipelineLegacy
25
+ from diffusers.configuration_utils import FrozenDict
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
28
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
29
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
30
+ from diffusers.schedulers import KarrasDiffusionSchedulers
31
+ from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
34
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
35
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
36
+
37
+ try:
38
+ from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
39
+ except ImportError:
40
+ USE_PEFT_BACKEND = False
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ def preprocess_image(image, batch_size):
46
+ w, h = image.size
47
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
48
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
49
+ image = np.array(image).astype(np.float32) / 255.0
50
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
51
+ image = torch.from_numpy(image)
52
+ return 2.0 * image - 1.0
53
+
54
+
55
+ def preprocess_mask(mask, batch_size, scale_factor=8):
56
+ if not isinstance(mask, torch.FloatTensor):
57
+ mask = mask.convert("L")
58
+ w, h = mask.size
59
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
60
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
61
+ mask = np.array(mask).astype(np.float32) / 255.0
62
+ mask = np.tile(mask, (4, 1, 1))
63
+ mask = np.vstack([mask[None]] * batch_size)
64
+ mask = 1 - mask # repaint white, keep black
65
+ mask = torch.from_numpy(mask)
66
+ return mask
67
+
68
+ else:
69
+ valid_mask_channel_sizes = [1, 3]
70
+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
71
+ if mask.shape[3] in valid_mask_channel_sizes:
72
+ mask = mask.permute(0, 3, 1, 2)
73
+ elif mask.shape[1] not in valid_mask_channel_sizes:
74
+ raise ValueError(
75
+ f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
76
+ f" but received mask of shape {tuple(mask.shape)}"
77
+ )
78
+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
79
+ mask = mask.mean(dim=1, keepdim=True)
80
+ h, w = mask.shape[-2:]
81
+ h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8
82
+ mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
83
+ return mask
84
+
85
+
86
+ class StableDiffusionInpaintPipelineLegacy(
87
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
88
+ ):
89
+ r"""
90
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
91
+
92
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
93
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
94
+
95
+ In addition the pipeline inherits the following loading methods:
96
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
97
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
98
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
99
+
100
+ as well as the following saving methods:
101
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
102
+
103
+ Args:
104
+ vae ([`AutoencoderKL`]):
105
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
106
+ text_encoder ([`CLIPTextModel`]):
107
+ Frozen text-encoder. Stable Diffusion uses the text portion of
108
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
109
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
110
+ tokenizer (`CLIPTokenizer`):
111
+ Tokenizer of class
112
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
113
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
114
+ scheduler ([`SchedulerMixin`]):
115
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
116
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
117
+ safety_checker ([`StableDiffusionSafetyChecker`]):
118
+ Classification module that estimates whether generated images could be considered offensive or harmful.
119
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
120
+ feature_extractor ([`CLIPImageProcessor`]):
121
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
122
+ """
123
+
124
+ model_cpu_offload_seq = "text_encoder->unet->vae"
125
+ _optional_components = ["feature_extractor"]
126
+ _exclude_from_cpu_offload = ["safety_checker"]
127
+
128
+ def __init__(
129
+ self,
130
+ vae: AutoencoderKL,
131
+ text_encoder: CLIPTextModel,
132
+ tokenizer: CLIPTokenizer,
133
+ unet: UNet2DConditionModel,
134
+ scheduler: KarrasDiffusionSchedulers,
135
+ safety_checker: StableDiffusionSafetyChecker,
136
+ feature_extractor: CLIPImageProcessor,
137
+ requires_safety_checker: bool = True,
138
+ ):
139
+ super().__init__()
140
+
141
+ deprecation_message = (
142
+ f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality"
143
+ "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533"
144
+ "for more information."
145
+ )
146
+ deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)
147
+
148
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
149
+ deprecation_message = (
150
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
151
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
152
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
153
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
154
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
155
+ " file"
156
+ )
157
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
158
+ new_config = dict(scheduler.config)
159
+ new_config["steps_offset"] = 1
160
+ scheduler._internal_dict = FrozenDict(new_config)
161
+
162
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
163
+ deprecation_message = (
164
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
165
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
166
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
167
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
168
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
169
+ )
170
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
171
+ new_config = dict(scheduler.config)
172
+ new_config["clip_sample"] = False
173
+ scheduler._internal_dict = FrozenDict(new_config)
174
+
175
+ if safety_checker is None and requires_safety_checker:
176
+ logger.warning(
177
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
178
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
179
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
180
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
181
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
182
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
183
+ )
184
+
185
+ if safety_checker is not None and feature_extractor is None:
186
+ raise ValueError(
187
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
188
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
189
+ )
190
+
191
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
192
+ version.parse(unet.config._diffusers_version).base_version
193
+ ) < version.parse("0.9.0.dev0")
194
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
195
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
196
+ deprecation_message = (
197
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
198
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
199
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
200
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
201
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
202
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
203
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
204
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
205
+ " the `unet/config.json` file"
206
+ )
207
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
208
+ new_config = dict(unet.config)
209
+ new_config["sample_size"] = 64
210
+ unet._internal_dict = FrozenDict(new_config)
211
+
212
+ self.register_modules(
213
+ vae=vae,
214
+ text_encoder=text_encoder,
215
+ tokenizer=tokenizer,
216
+ unet=unet,
217
+ scheduler=scheduler,
218
+ safety_checker=safety_checker,
219
+ feature_extractor=feature_extractor,
220
+ )
221
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
222
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
223
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
224
+
225
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
226
+ def _encode_prompt(
227
+ self,
228
+ prompt,
229
+ device,
230
+ num_images_per_prompt,
231
+ do_classifier_free_guidance,
232
+ negative_prompt=None,
233
+ prompt_embeds: Optional[torch.FloatTensor] = None,
234
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
235
+ lora_scale: Optional[float] = None,
236
+ **kwargs,
237
+ ):
238
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
239
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
240
+
241
+ prompt_embeds_tuple = self.encode_prompt(
242
+ prompt=prompt,
243
+ device=device,
244
+ num_images_per_prompt=num_images_per_prompt,
245
+ do_classifier_free_guidance=do_classifier_free_guidance,
246
+ negative_prompt=negative_prompt,
247
+ prompt_embeds=prompt_embeds,
248
+ negative_prompt_embeds=negative_prompt_embeds,
249
+ lora_scale=lora_scale,
250
+ **kwargs,
251
+ )
252
+
253
+ # concatenate for backwards compatibility
254
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
255
+
256
+ return prompt_embeds
257
+
258
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
259
+ def encode_prompt(
260
+ self,
261
+ prompt,
262
+ device,
263
+ num_images_per_prompt,
264
+ do_classifier_free_guidance,
265
+ negative_prompt=None,
266
+ prompt_embeds: Optional[torch.FloatTensor] = None,
267
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
268
+ lora_scale: Optional[float] = None,
269
+ clip_skip: Optional[int] = None,
270
+ ):
271
+ r"""
272
+ Encodes the prompt into text encoder hidden states.
273
+
274
+ Args:
275
+ prompt (`str` or `List[str]`, *optional*):
276
+ prompt to be encoded
277
+ device: (`torch.device`):
278
+ torch device
279
+ num_images_per_prompt (`int`):
280
+ number of images that should be generated per prompt
281
+ do_classifier_free_guidance (`bool`):
282
+ whether to use classifier free guidance or not
283
+ negative_prompt (`str` or `List[str]`, *optional*):
284
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
285
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
286
+ less than `1`).
287
+ prompt_embeds (`torch.FloatTensor`, *optional*):
288
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
289
+ provided, text embeddings will be generated from `prompt` input argument.
290
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
291
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
292
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
293
+ argument.
294
+ lora_scale (`float`, *optional*):
295
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
296
+ clip_skip (`int`, *optional*):
297
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
298
+ the output of the pre-final layer will be used for computing the prompt embeddings.
299
+ """
300
+ # set lora scale so that monkey patched LoRA
301
+ # function of text encoder can correctly access it
302
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
303
+ self._lora_scale = lora_scale
304
+
305
+ # dynamically adjust the LoRA scale
306
+ if not USE_PEFT_BACKEND:
307
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
308
+ else:
309
+ scale_lora_layers(self.text_encoder, lora_scale)
310
+
311
+ if prompt is not None and isinstance(prompt, str):
312
+ batch_size = 1
313
+ elif prompt is not None and isinstance(prompt, list):
314
+ batch_size = len(prompt)
315
+ else:
316
+ batch_size = prompt_embeds.shape[0]
317
+
318
+ if prompt_embeds is None:
319
+ # textual inversion: process multi-vector tokens if necessary
320
+ if isinstance(self, TextualInversionLoaderMixin):
321
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
322
+
323
+ text_inputs = self.tokenizer(
324
+ prompt,
325
+ padding="max_length",
326
+ max_length=self.tokenizer.model_max_length,
327
+ truncation=True,
328
+ return_tensors="pt",
329
+ )
330
+ text_input_ids = text_inputs.input_ids
331
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
332
+
333
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
334
+ text_input_ids, untruncated_ids
335
+ ):
336
+ removed_text = self.tokenizer.batch_decode(
337
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
338
+ )
339
+ logger.warning(
340
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
341
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
342
+ )
343
+
344
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
345
+ attention_mask = text_inputs.attention_mask.to(device)
346
+ else:
347
+ attention_mask = None
348
+
349
+ if clip_skip is None:
350
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
351
+ prompt_embeds = prompt_embeds[0]
352
+ else:
353
+ prompt_embeds = self.text_encoder(
354
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
355
+ )
356
+ # Access the `hidden_states` first, which contains a tuple of
357
+ # all the hidden states from the encoder layers. Then index into
358
+ # the tuple to access the hidden states from the desired layer.
359
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
360
+ # We also need to apply the final LayerNorm here to not mess with the
361
+ # representations. The `last_hidden_states` that we typically use for
362
+ # obtaining the final prompt representations passes through the LayerNorm
363
+ # layer.
364
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
365
+
366
+ if self.text_encoder is not None:
367
+ prompt_embeds_dtype = self.text_encoder.dtype
368
+ elif self.unet is not None:
369
+ prompt_embeds_dtype = self.unet.dtype
370
+ else:
371
+ prompt_embeds_dtype = prompt_embeds.dtype
372
+
373
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
374
+
375
+ bs_embed, seq_len, _ = prompt_embeds.shape
376
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
377
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
378
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
379
+
380
+ # get unconditional embeddings for classifier free guidance
381
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
382
+ uncond_tokens: List[str]
383
+ if negative_prompt is None:
384
+ uncond_tokens = [""] * batch_size
385
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
386
+ raise TypeError(
387
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
388
+ f" {type(prompt)}."
389
+ )
390
+ elif isinstance(negative_prompt, str):
391
+ uncond_tokens = [negative_prompt]
392
+ elif batch_size != len(negative_prompt):
393
+ raise ValueError(
394
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
395
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
396
+ " the batch size of `prompt`."
397
+ )
398
+ else:
399
+ uncond_tokens = negative_prompt
400
+
401
+ # textual inversion: process multi-vector tokens if necessary
402
+ if isinstance(self, TextualInversionLoaderMixin):
403
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
404
+
405
+ max_length = prompt_embeds.shape[1]
406
+ uncond_input = self.tokenizer(
407
+ uncond_tokens,
408
+ padding="max_length",
409
+ max_length=max_length,
410
+ truncation=True,
411
+ return_tensors="pt",
412
+ )
413
+
414
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
415
+ attention_mask = uncond_input.attention_mask.to(device)
416
+ else:
417
+ attention_mask = None
418
+
419
+ negative_prompt_embeds = self.text_encoder(
420
+ uncond_input.input_ids.to(device),
421
+ attention_mask=attention_mask,
422
+ )
423
+ negative_prompt_embeds = negative_prompt_embeds[0]
424
+
425
+ if do_classifier_free_guidance:
426
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
427
+ seq_len = negative_prompt_embeds.shape[1]
428
+
429
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
430
+
431
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
432
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
433
+
434
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
435
+ # Retrieve the original scale by scaling back the LoRA layers
436
+ unscale_lora_layers(self.text_encoder, lora_scale)
437
+
438
+ return prompt_embeds, negative_prompt_embeds
439
+
440
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
441
+ def run_safety_checker(self, image, device, dtype):
442
+ if self.safety_checker is None:
443
+ has_nsfw_concept = None
444
+ else:
445
+ if torch.is_tensor(image):
446
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
447
+ else:
448
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
449
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
450
+ image, has_nsfw_concept = self.safety_checker(
451
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
452
+ )
453
+ return image, has_nsfw_concept
454
+
455
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
456
+ def decode_latents(self, latents):
457
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
458
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
459
+
460
+ latents = 1 / self.vae.config.scaling_factor * latents
461
+ image = self.vae.decode(latents, return_dict=False)[0]
462
+ image = (image / 2 + 0.5).clamp(0, 1)
463
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
464
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
465
+ return image
466
+
467
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
468
+ def prepare_extra_step_kwargs(self, generator, eta):
469
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
470
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
471
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
472
+ # and should be between [0, 1]
473
+
474
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
475
+ extra_step_kwargs = {}
476
+ if accepts_eta:
477
+ extra_step_kwargs["eta"] = eta
478
+
479
+ # check if the scheduler accepts generator
480
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
481
+ if accepts_generator:
482
+ extra_step_kwargs["generator"] = generator
483
+ return extra_step_kwargs
484
+
485
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
486
+ def check_inputs(
487
+ self,
488
+ prompt,
489
+ strength,
490
+ callback_steps,
491
+ negative_prompt=None,
492
+ prompt_embeds=None,
493
+ negative_prompt_embeds=None,
494
+ callback_on_step_end_tensor_inputs=None,
495
+ ):
496
+ if strength < 0 or strength > 1:
497
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
498
+
499
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
500
+ raise ValueError(
501
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
502
+ f" {type(callback_steps)}."
503
+ )
504
+
505
+ if callback_on_step_end_tensor_inputs is not None and not all(
506
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
507
+ ):
508
+ raise ValueError(
509
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
510
+ )
511
+ if prompt is not None and prompt_embeds is not None:
512
+ raise ValueError(
513
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
514
+ " only forward one of the two."
515
+ )
516
+ elif prompt is None and prompt_embeds is None:
517
+ raise ValueError(
518
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
519
+ )
520
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
521
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
522
+
523
+ if negative_prompt is not None and negative_prompt_embeds is not None:
524
+ raise ValueError(
525
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
526
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
527
+ )
528
+
529
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
530
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
531
+ raise ValueError(
532
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
533
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
534
+ f" {negative_prompt_embeds.shape}."
535
+ )
536
+
537
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
538
+ def get_timesteps(self, num_inference_steps, strength, device):
539
+ # get the original timestep using init_timestep
540
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
541
+
542
+ t_start = max(num_inference_steps - init_timestep, 0)
543
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
544
+
545
+ return timesteps, num_inference_steps - t_start
546
+
547
+ def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
548
+ image = image.to(device=device, dtype=dtype)
549
+ init_latent_dist = self.vae.encode(image).latent_dist
550
+ init_latents = init_latent_dist.sample(generator=generator)
551
+ init_latents = self.vae.config.scaling_factor * init_latents
552
+
553
+ # Expand init_latents for batch_size and num_images_per_prompt
554
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
555
+ init_latents_orig = init_latents
556
+
557
+ # add noise to latents using the timesteps
558
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
559
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
560
+ latents = init_latents
561
+ return latents, init_latents_orig, noise
562
+
563
+ @torch.no_grad()
564
+ def __call__(
565
+ self,
566
+ prompt: Union[str, List[str]] = None,
567
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
568
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
569
+ strength: float = 0.8,
570
+ num_inference_steps: Optional[int] = 50,
571
+ guidance_scale: Optional[float] = 7.5,
572
+ negative_prompt: Optional[Union[str, List[str]]] = None,
573
+ num_images_per_prompt: Optional[int] = 1,
574
+ add_predicted_noise: Optional[bool] = False,
575
+ eta: Optional[float] = 0.0,
576
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
577
+ prompt_embeds: Optional[torch.FloatTensor] = None,
578
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
579
+ output_type: Optional[str] = "pil",
580
+ return_dict: bool = True,
581
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
582
+ callback_steps: int = 1,
583
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
584
+ clip_skip: Optional[int] = None,
585
+ ):
586
+ r"""
587
+ Function invoked when calling the pipeline for generation.
588
+
589
+ Args:
590
+ prompt (`str` or `List[str]`, *optional*):
591
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
592
+ instead.
593
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
594
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
595
+ process. This is the image whose masked region will be inpainted.
596
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
597
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
598
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
599
+ PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
600
+ expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
601
+ strength (`float`, *optional*, defaults to 0.8):
602
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
603
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
604
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
605
+ that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
606
+ num_inference_steps (`int`, *optional*, defaults to 50):
607
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
608
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
609
+ guidance_scale (`float`, *optional*, defaults to 7.5):
610
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
611
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
612
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
613
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
614
+ usually at the expense of lower image quality.
615
+ negative_prompt (`str` or `List[str]`, *optional*):
616
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
617
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
618
+ is less than `1`).
619
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
620
+ The number of images to generate per prompt.
621
+ add_predicted_noise (`bool`, *optional*, defaults to `False`):
622
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
623
+ the reverse diffusion process
624
+ eta (`float`, *optional*, defaults to 0.0):
625
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
626
+ [`schedulers.DDIMScheduler`], will be ignored for others.
627
+ generator (`torch.Generator`, *optional*):
628
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
629
+ to make generation deterministic.
630
+ prompt_embeds (`torch.FloatTensor`, *optional*):
631
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
632
+ provided, text embeddings will be generated from `prompt` input argument.
633
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
634
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
635
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
636
+ argument.
637
+ output_type (`str`, *optional*, defaults to `"pil"`):
638
+ The output format of the generated image. Choose between
639
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
640
+ return_dict (`bool`, *optional*, defaults to `True`):
641
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
642
+ plain tuple.
643
+ callback (`Callable`, *optional*):
644
+ A function that will be called every `callback_steps` steps during inference. The function will be
645
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
646
+ callback_steps (`int`, *optional*, defaults to 1):
647
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
648
+ called at every step.
649
+ cross_attention_kwargs (`dict`, *optional*):
650
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
651
+ `self.processor` in
652
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
653
+ clip_skip (`int`, *optional*):
654
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
655
+ the output of the pre-final layer will be used for computing the prompt embeddings.
656
+
657
+ Returns:
658
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
659
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
660
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
661
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
662
+ (nsfw) content, according to the `safety_checker`.
663
+ """
664
+ # 1. Check inputs
665
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
666
+
667
+ # 2. Define call parameters
668
+ if prompt is not None and isinstance(prompt, str):
669
+ batch_size = 1
670
+ elif prompt is not None and isinstance(prompt, list):
671
+ batch_size = len(prompt)
672
+ else:
673
+ batch_size = prompt_embeds.shape[0]
674
+
675
+ device = self._execution_device
676
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
677
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
678
+ # corresponds to doing no classifier free guidance.
679
+ do_classifier_free_guidance = guidance_scale > 1.0
680
+
681
+ # 3. Encode input prompt
682
+ text_encoder_lora_scale = (
683
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
684
+ )
685
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
686
+ prompt,
687
+ device,
688
+ num_images_per_prompt,
689
+ do_classifier_free_guidance,
690
+ negative_prompt,
691
+ prompt_embeds=prompt_embeds,
692
+ negative_prompt_embeds=negative_prompt_embeds,
693
+ lora_scale=text_encoder_lora_scale,
694
+ clip_skip=clip_skip,
695
+ )
696
+ # For classifier free guidance, we need to do two forward passes.
697
+ # Here we concatenate the unconditional and text embeddings into a single batch
698
+ # to avoid doing two forward passes
699
+ if do_classifier_free_guidance:
700
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
701
+
702
+ # 4. Preprocess image and mask
703
+ if not isinstance(image, torch.FloatTensor):
704
+ image = preprocess_image(image, batch_size)
705
+
706
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
707
+
708
+ # 5. set timesteps
709
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
710
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
711
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
712
+
713
+ # 6. Prepare latent variables
714
+ # encode the init image into latents and scale the latents
715
+ latents, init_latents_orig, noise = self.prepare_latents(
716
+ image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
717
+ )
718
+
719
+ # 7. Prepare mask latent
720
+ mask = mask_image.to(device=device, dtype=latents.dtype)
721
+ mask = torch.cat([mask] * num_images_per_prompt)
722
+
723
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
724
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
725
+
726
+ # 9. Denoising loop
727
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
728
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
729
+ for i, t in enumerate(timesteps):
730
+ # expand the latents if we are doing classifier free guidance
731
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
732
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
733
+
734
+ # predict the noise residual
735
+ noise_pred = self.unet(
736
+ latent_model_input,
737
+ t,
738
+ encoder_hidden_states=prompt_embeds,
739
+ cross_attention_kwargs=cross_attention_kwargs,
740
+ return_dict=False,
741
+ )[0]
742
+
743
+ # perform guidance
744
+ if do_classifier_free_guidance:
745
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
746
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
747
+
748
+ # compute the previous noisy sample x_t -> x_t-1
749
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
750
+ # masking
751
+ if add_predicted_noise:
752
+ init_latents_proper = self.scheduler.add_noise(
753
+ init_latents_orig, noise_pred_uncond, torch.tensor([t])
754
+ )
755
+ else:
756
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
757
+
758
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
759
+
760
+ # call the callback, if provided
761
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
762
+ progress_bar.update()
763
+ if callback is not None and i % callback_steps == 0:
764
+ step_idx = i // getattr(self.scheduler, "order", 1)
765
+ callback(step_idx, t, latents)
766
+
767
+ # use original latents corresponding to unmasked portions of the image
768
+ latents = (init_latents_orig * mask) + (latents * (1 - mask))
769
+
770
+ if not output_type == "latent":
771
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
772
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
773
+ else:
774
+ image = latents
775
+ has_nsfw_concept = None
776
+
777
+ if has_nsfw_concept is None:
778
+ do_denormalize = [True] * image.shape[0]
779
+ else:
780
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
781
+
782
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
783
+
784
+ # Offload all models
785
+ self.maybe_free_model_hooks()
786
+
787
+ if not return_dict:
788
+ return (image, has_nsfw_concept)
789
+
790
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
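
For orientation, below is a minimal usage sketch of the vendored legacy inpaint pipeline shown above; it is not part of the package diff. It assumes an SD 1.5-style checkpoint layout (e.g. runwayml/stable-diffusion-v1-5), a CUDA device, and hypothetical local files photo.png and mask.png:

import torch
from PIL import Image

from hcpdiff.utils.inpaint_pipe import StableDiffusionInpaintPipelineLegacy

# Load the vendored legacy pipeline with components from an SD 1.5-style repo.
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("photo.png").convert("RGB")
# Per preprocess_mask above: white pixels are repainted, black pixels are kept.
mask_image = Image.open("mask.png").convert("L")

result = pipe(
    prompt="a red brick wall",
    image=init_image,
    mask_image=mask_image,
    strength=0.75,
    num_inference_steps=50,
    guidance_scale=7.5,
)
result.images[0].save("inpainted.png")

Because `strength` modulates how many of the scheduled steps are actually run (see `get_timesteps` above), `strength=0.75` denoises the masked region for roughly the last 37 of the 50 scheduled steps.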