diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,832 @@
1
+ # Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
20
+
21
+ from ...image_processor import VaeImageProcessor
22
+ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
23
+ from ...models import AutoencoderKL, UNet2DConditionModel
24
+ from ...schedulers import KarrasDiffusionSchedulers
25
+ from ...utils import (
26
+ is_accelerate_available,
27
+ is_accelerate_version,
28
+ logging,
29
+ randn_tensor,
30
+ replace_example_docstring,
31
+ )
32
+ from ..pipeline_utils import DiffusionPipeline
33
+ from . import StableDiffusionPipelineOutput
34
+ from .safety_checker import StableDiffusionSafetyChecker
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+ EXAMPLE_DOC_STRING = """
40
+ Examples:
41
+ ```py
42
+ >>> import torch
43
+ >>> from diffusers import DDPMParallelScheduler
44
+ >>> from diffusers import StableDiffusionParadigmsPipeline
45
+
46
+ >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
47
+
48
+ >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(
49
+ ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
50
+ ... )
51
+ >>> pipe = pipe.to("cuda")
52
+
53
+ >>> ngpu, batch_per_device = torch.cuda.device_count(), 5
54
+ >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)])
55
+
56
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
57
+ >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0]
58
+ ```
59
+ """
60
+
61
+
62
+ class StableDiffusionParadigmsPipeline(
63
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
64
+ ):
65
+ r"""
66
+ Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317. This pipeline
67
+ parallelizes the denoising steps to generate a single image faster (more akin to model parallelism).
68
+
69
+ Pipeline for text-to-image generation using Stable Diffusion.
70
+
71
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
72
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
73
+
74
+ In addition, the pipeline inherits the following loading methods:
75
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
76
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
77
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
78
+
79
+ as well as the following saving methods:
80
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
81
+
82
+ Args:
83
+ vae ([`AutoencoderKL`]):
84
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
85
+ text_encoder ([`CLIPTextModel`]):
86
+ Frozen text-encoder. Stable Diffusion uses the text portion of
87
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
88
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
89
+ tokenizer (`CLIPTokenizer`):
90
+ Tokenizer of class
91
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
92
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
93
+ scheduler ([`SchedulerMixin`]):
94
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
95
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
96
+ safety_checker ([`StableDiffusionSafetyChecker`]):
97
+ Classification module that estimates whether generated images could be considered offensive or harmful.
98
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
99
+ feature_extractor ([`CLIPImageProcessor`]):
100
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
101
+ """
102
+ _optional_components = ["safety_checker", "feature_extractor"]
103
+
104
+ def __init__(
105
+ self,
106
+ vae: AutoencoderKL,
107
+ text_encoder: CLIPTextModel,
108
+ tokenizer: CLIPTokenizer,
109
+ unet: UNet2DConditionModel,
110
+ scheduler: KarrasDiffusionSchedulers,
111
+ safety_checker: StableDiffusionSafetyChecker,
112
+ feature_extractor: CLIPImageProcessor,
113
+ requires_safety_checker: bool = True,
114
+ ):
115
+ super().__init__()
116
+
117
+ if safety_checker is None and requires_safety_checker:
118
+ logger.warning(
119
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
120
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
121
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
122
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
123
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
124
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
125
+ )
126
+
127
+ if safety_checker is not None and feature_extractor is None:
128
+ raise ValueError(
129
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
130
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
131
+ )
132
+
133
+ self.register_modules(
134
+ vae=vae,
135
+ text_encoder=text_encoder,
136
+ tokenizer=tokenizer,
137
+ unet=unet,
138
+ scheduler=scheduler,
139
+ safety_checker=safety_checker,
140
+ feature_extractor=feature_extractor,
141
+ )
142
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
143
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
144
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
145
+
146
+ # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs
147
+ self.wrapped_unet = self.unet
148
+
149
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
150
+ def enable_vae_slicing(self):
151
+ r"""
152
+ Enable sliced VAE decoding.
153
+
154
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
155
+ steps. This is useful to save some memory and allow larger batch sizes.
156
+ """
157
+ self.vae.enable_slicing()
158
+
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
160
+ def disable_vae_slicing(self):
161
+ r"""
162
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
163
+ computing decoding in one step.
164
+ """
165
+ self.vae.disable_slicing()
166
+
167
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
168
+ def enable_vae_tiling(self):
169
+ r"""
170
+ Enable tiled VAE decoding.
171
+
172
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
173
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
174
+ """
175
+ self.vae.enable_tiling()
176
+
177
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
178
+ def disable_vae_tiling(self):
179
+ r"""
180
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
181
+ computing decoding in one step.
182
+ """
183
+ self.vae.disable_tiling()
184
+
185
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
186
+ def enable_sequential_cpu_offload(self, gpu_id=0):
187
+ r"""
188
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
189
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
190
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
191
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
192
+ `enable_model_cpu_offload`, but performance is lower.
193
+ """
194
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
195
+ from accelerate import cpu_offload
196
+ else:
197
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
198
+
199
+ device = torch.device(f"cuda:{gpu_id}")
200
+
201
+ if self.device.type != "cpu":
202
+ self.to("cpu", silence_dtype_warnings=True)
203
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
204
+
205
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
206
+ cpu_offload(cpu_offloaded_model, device)
207
+
208
+ if self.safety_checker is not None:
209
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
210
+
211
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
212
+ def enable_model_cpu_offload(self, gpu_id=0):
213
+ r"""
214
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
215
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
216
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
217
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
218
+ """
219
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
220
+ from accelerate import cpu_offload_with_hook
221
+ else:
222
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
223
+
224
+ device = torch.device(f"cuda:{gpu_id}")
225
+
226
+ if self.device.type != "cpu":
227
+ self.to("cpu", silence_dtype_warnings=True)
228
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
229
+
230
+ hook = None
231
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
232
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
233
+
234
+ if self.safety_checker is not None:
235
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
236
+
237
+ # We'll offload the last model manually.
238
+ self.final_offload_hook = hook
239
+
240
+ @property
241
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
242
+ def _execution_device(self):
243
+ r"""
244
+ Returns the device on which the pipeline's models will be executed. After calling
245
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
246
+ hooks.
247
+ """
248
+ if not hasattr(self.unet, "_hf_hook"):
249
+ return self.device
250
+ for module in self.unet.modules():
251
+ if (
252
+ hasattr(module, "_hf_hook")
253
+ and hasattr(module._hf_hook, "execution_device")
254
+ and module._hf_hook.execution_device is not None
255
+ ):
256
+ return torch.device(module._hf_hook.execution_device)
257
+ return self.device
258
+
259
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
260
+ def _encode_prompt(
261
+ self,
262
+ prompt,
263
+ device,
264
+ num_images_per_prompt,
265
+ do_classifier_free_guidance,
266
+ negative_prompt=None,
267
+ prompt_embeds: Optional[torch.FloatTensor] = None,
268
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
269
+ lora_scale: Optional[float] = None,
270
+ ):
271
+ r"""
272
+ Encodes the prompt into text encoder hidden states.
273
+
274
+ Args:
275
+ prompt (`str` or `List[str]`, *optional*):
276
+ prompt to be encoded
277
+ device (`torch.device`):
278
+ torch device
279
+ num_images_per_prompt (`int`):
280
+ number of images that should be generated per prompt
281
+ do_classifier_free_guidance (`bool`):
282
+ whether to use classifier free guidance or not
283
+ negative_prompt (`str` or `List[str]`, *optional*):
284
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
285
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
286
+ less than `1`).
287
+ prompt_embeds (`torch.FloatTensor`, *optional*):
288
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
289
+ provided, text embeddings will be generated from `prompt` input argument.
290
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
291
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
292
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
293
+ argument.
294
+ lora_scale (`float`, *optional*):
295
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
296
+ """
297
+ # set lora scale so that monkey patched LoRA
298
+ # function of text encoder can correctly access it
299
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
300
+ self._lora_scale = lora_scale
301
+
302
+ if prompt is not None and isinstance(prompt, str):
303
+ batch_size = 1
304
+ elif prompt is not None and isinstance(prompt, list):
305
+ batch_size = len(prompt)
306
+ else:
307
+ batch_size = prompt_embeds.shape[0]
308
+
309
+ if prompt_embeds is None:
310
+ # textual inversion: process multi-vector tokens if necessary
311
+ if isinstance(self, TextualInversionLoaderMixin):
312
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
313
+
314
+ text_inputs = self.tokenizer(
315
+ prompt,
316
+ padding="max_length",
317
+ max_length=self.tokenizer.model_max_length,
318
+ truncation=True,
319
+ return_tensors="pt",
320
+ )
321
+ text_input_ids = text_inputs.input_ids
322
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
323
+
324
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
325
+ text_input_ids, untruncated_ids
326
+ ):
327
+ removed_text = self.tokenizer.batch_decode(
328
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
329
+ )
330
+ logger.warning(
331
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
332
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
333
+ )
334
+
335
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
336
+ attention_mask = text_inputs.attention_mask.to(device)
337
+ else:
338
+ attention_mask = None
339
+
340
+ prompt_embeds = self.text_encoder(
341
+ text_input_ids.to(device),
342
+ attention_mask=attention_mask,
343
+ )
344
+ prompt_embeds = prompt_embeds[0]
345
+
346
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
347
+
348
+ bs_embed, seq_len, _ = prompt_embeds.shape
349
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
350
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
351
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
352
+
353
+ # get unconditional embeddings for classifier free guidance
354
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
355
+ uncond_tokens: List[str]
356
+ if negative_prompt is None:
357
+ uncond_tokens = [""] * batch_size
358
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
359
+ raise TypeError(
360
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
361
+ f" {type(prompt)}."
362
+ )
363
+ elif isinstance(negative_prompt, str):
364
+ uncond_tokens = [negative_prompt]
365
+ elif batch_size != len(negative_prompt):
366
+ raise ValueError(
367
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
368
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
369
+ " the batch size of `prompt`."
370
+ )
371
+ else:
372
+ uncond_tokens = negative_prompt
373
+
374
+ # textual inversion: process multi-vector tokens if necessary
375
+ if isinstance(self, TextualInversionLoaderMixin):
376
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
377
+
378
+ max_length = prompt_embeds.shape[1]
379
+ uncond_input = self.tokenizer(
380
+ uncond_tokens,
381
+ padding="max_length",
382
+ max_length=max_length,
383
+ truncation=True,
384
+ return_tensors="pt",
385
+ )
386
+
387
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
388
+ attention_mask = uncond_input.attention_mask.to(device)
389
+ else:
390
+ attention_mask = None
391
+
392
+ negative_prompt_embeds = self.text_encoder(
393
+ uncond_input.input_ids.to(device),
394
+ attention_mask=attention_mask,
395
+ )
396
+ negative_prompt_embeds = negative_prompt_embeds[0]
397
+
398
+ if do_classifier_free_guidance:
399
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
400
+ seq_len = negative_prompt_embeds.shape[1]
401
+
402
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
403
+
404
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
405
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
406
+
407
+ # For classifier free guidance, we need to do two forward passes.
408
+ # Here we concatenate the unconditional and text embeddings into a single batch
409
+ # to avoid doing two forward passes
410
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
411
+
412
+ return prompt_embeds
413
+
414
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
415
+ def run_safety_checker(self, image, device, dtype):
416
+ if self.safety_checker is None:
417
+ has_nsfw_concept = None
418
+ else:
419
+ if torch.is_tensor(image):
420
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
421
+ else:
422
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
423
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
424
+ image, has_nsfw_concept = self.safety_checker(
425
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
426
+ )
427
+ return image, has_nsfw_concept
428
+
429
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
430
+ def prepare_extra_step_kwargs(self, generator, eta):
431
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
432
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
433
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
434
+ # and should be between [0, 1]
435
+
436
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
437
+ extra_step_kwargs = {}
438
+ if accepts_eta:
439
+ extra_step_kwargs["eta"] = eta
440
+
441
+ # check if the scheduler accepts generator
442
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
443
+ if accepts_generator:
444
+ extra_step_kwargs["generator"] = generator
445
+ return extra_step_kwargs
446
+
447
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
448
+ def check_inputs(
449
+ self,
450
+ prompt,
451
+ height,
452
+ width,
453
+ callback_steps,
454
+ negative_prompt=None,
455
+ prompt_embeds=None,
456
+ negative_prompt_embeds=None,
457
+ ):
458
+ if height % 8 != 0 or width % 8 != 0:
459
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
460
+
461
+ if (callback_steps is None) or (
462
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
463
+ ):
464
+ raise ValueError(
465
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
466
+ f" {type(callback_steps)}."
467
+ )
468
+
469
+ if prompt is not None and prompt_embeds is not None:
470
+ raise ValueError(
471
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
472
+ " only forward one of the two."
473
+ )
474
+ elif prompt is None and prompt_embeds is None:
475
+ raise ValueError(
476
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
477
+ )
478
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
479
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
480
+
481
+ if negative_prompt is not None and negative_prompt_embeds is not None:
482
+ raise ValueError(
483
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
484
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
485
+ )
486
+
487
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
488
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
489
+ raise ValueError(
490
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
491
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
492
+ f" {negative_prompt_embeds.shape}."
493
+ )
494
+
495
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
496
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
497
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
498
+ if isinstance(generator, list) and len(generator) != batch_size:
499
+ raise ValueError(
500
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
501
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
502
+ )
503
+
504
+ if latents is None:
505
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
506
+ else:
507
+ latents = latents.to(device)
508
+
509
+ # scale the initial noise by the standard deviation required by the scheduler
510
+ latents = latents * self.scheduler.init_noise_sigma
511
+ return latents
512
+
513
+ def _cumsum(self, input, dim, debug=False):
514
+ if debug:
515
+ # cumsum_cuda_kernel does not have a deterministic implementation
516
+ # so perform cumsum on cpu for debugging purposes
517
+ return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)
518
+ else:
519
+ return torch.cumsum(input, dim=dim)
520
+
521
+ @torch.no_grad()
522
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
523
+ def __call__(
524
+ self,
525
+ prompt: Union[str, List[str]] = None,
526
+ height: Optional[int] = None,
527
+ width: Optional[int] = None,
528
+ num_inference_steps: int = 50,
529
+ parallel: int = 10,
530
+ tolerance: float = 0.1,
531
+ guidance_scale: float = 7.5,
532
+ negative_prompt: Optional[Union[str, List[str]]] = None,
533
+ num_images_per_prompt: Optional[int] = 1,
534
+ eta: float = 0.0,
535
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
536
+ latents: Optional[torch.FloatTensor] = None,
537
+ prompt_embeds: Optional[torch.FloatTensor] = None,
538
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
539
+ output_type: Optional[str] = "pil",
540
+ return_dict: bool = True,
541
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
542
+ callback_steps: int = 1,
543
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
544
+ debug: bool = False,
545
+ ):
546
+ r"""
547
+ Function invoked when calling the pipeline for generation.
548
+
549
+ Args:
550
+ prompt (`str` or `List[str]`, *optional*):
551
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
552
+ instead.
553
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
554
+ The height in pixels of the generated image.
555
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
556
+ The width in pixels of the generated image.
557
+ num_inference_steps (`int`, *optional*, defaults to 50):
558
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
559
+ expense of slower inference.
560
+ parallel (`int`, *optional*, defaults to 10):
561
+ The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but
562
+ requires higher memory usage and also can require more total FLOPs.
563
+ tolerance (`float`, *optional*, defaults to 0.1):
564
+ The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower
565
+ tolerance usually leads to less/no degradation. Higher tolerance is faster but can risk degradation of
566
+ sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude.
567
+ guidance_scale (`float`, *optional*, defaults to 7.5):
568
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
569
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
570
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
571
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
572
+ usually at the expense of lower image quality.
573
+ negative_prompt (`str` or `List[str]`, *optional*):
574
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
575
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
576
+ less than `1`).
577
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
578
+ The number of images to generate per prompt.
579
+ eta (`float`, *optional*, defaults to 0.0):
580
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
581
+ [`schedulers.DDIMScheduler`], will be ignored for others.
582
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
583
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
584
+ to make generation deterministic.
585
+ latents (`torch.FloatTensor`, *optional*):
586
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
587
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
588
+ tensor will be generated by sampling using the supplied random `generator`.
589
+ prompt_embeds (`torch.FloatTensor`, *optional*):
590
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
591
+ provided, text embeddings will be generated from `prompt` input argument.
592
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
593
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
594
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
595
+ argument.
596
+ output_type (`str`, *optional*, defaults to `"pil"`):
597
+ The output format of the generated image. Choose between
598
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
599
+ return_dict (`bool`, *optional*, defaults to `True`):
600
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
601
+ plain tuple.
602
+ callback (`Callable`, *optional*):
603
+ A function that will be called every `callback_steps` steps during inference. The function will be
604
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
605
+ callback_steps (`int`, *optional*, defaults to 1):
606
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
607
+ called at every step.
608
+ cross_attention_kwargs (`dict`, *optional*):
609
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
610
+ `self.processor` in
611
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
612
+ debug (`bool`, *optional*, defaults to `False`):
613
+ Whether or not to run in debug mode. In debug mode, torch.cumsum is evaluated using the CPU.
614
+
615
+ Examples:
616
+
617
+ Returns:
618
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
619
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
620
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
621
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
622
+ (nsfw) content, according to the `safety_checker`.
623
+ """
624
+ # 0. Default height and width to unet
625
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
626
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
627
+
628
+ # 1. Check inputs. Raise error if not correct
629
+ self.check_inputs(
630
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
631
+ )
632
+
633
+ # 2. Define call parameters
634
+ if prompt is not None and isinstance(prompt, str):
635
+ batch_size = 1
636
+ elif prompt is not None and isinstance(prompt, list):
637
+ batch_size = len(prompt)
638
+ else:
639
+ batch_size = prompt_embeds.shape[0]
640
+
641
+ device = self._execution_device
642
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
643
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
644
+ # corresponds to doing no classifier free guidance.
645
+ do_classifier_free_guidance = guidance_scale > 1.0
646
+
647
+ # 3. Encode input prompt
648
+ prompt_embeds = self._encode_prompt(
649
+ prompt,
650
+ device,
651
+ num_images_per_prompt,
652
+ do_classifier_free_guidance,
653
+ negative_prompt,
654
+ prompt_embeds=prompt_embeds,
655
+ negative_prompt_embeds=negative_prompt_embeds,
656
+ )
657
+
658
+ # 4. Prepare timesteps
659
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
660
+
661
+ # 5. Prepare latent variables
662
+ num_channels_latents = self.unet.config.in_channels
663
+ latents = self.prepare_latents(
664
+ batch_size * num_images_per_prompt,
665
+ num_channels_latents,
666
+ height,
667
+ width,
668
+ prompt_embeds.dtype,
669
+ device,
670
+ generator,
671
+ latents,
672
+ )
673
+
674
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
675
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
676
+ extra_step_kwargs.pop("generator", None)
677
+
678
+ # 7. Denoising loop
679
+ scheduler = self.scheduler
680
+ parallel = min(parallel, len(scheduler.timesteps))
681
+
682
+ begin_idx = 0
683
+ end_idx = parallel
684
+ latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1))
685
+
686
+ # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep.
687
+ # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop.
688
+ noise_array = torch.zeros_like(latents_time_evolution_buffer)
689
+ for j in range(len(scheduler.timesteps)):
690
+ base_noise = randn_tensor(
691
+ shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype
692
+ )
693
+ noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise
694
+ noise_array[j] = noise.clone()
695
+
696
+ # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance
697
+ # outside of the denoising loop to avoid recomputing it at every step.
698
+ # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step.
699
+ inverse_variance_norm = 1.0 / torch.tensor(
700
+ [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0]
701
+ ).to(noise_array.device)
702
+ latent_dim = noise_array[0, 0].numel()
703
+ inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim
704
+
705
+ scaled_tolerance = tolerance**2
706
+
707
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
708
+ steps = 0
709
+ while begin_idx < len(scheduler.timesteps):
710
+ # these have shape (parallel_dim, 2*batch_size, ...)
711
+ # parallel_len is at most parallel, but could be less if we are at the end of the timesteps
712
+ # we are processing batch window of timesteps spanning [begin_idx, end_idx)
713
+ parallel_len = end_idx - begin_idx
714
+
715
+ block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len)
716
+ block_latents = latents_time_evolution_buffer[begin_idx:end_idx]
717
+ block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt)
718
+ t_vec = block_t
719
+ if do_classifier_free_guidance:
720
+ t_vec = t_vec.repeat(1, 2)
721
+
722
+ # expand the latents if we are doing classifier free guidance
723
+ latent_model_input = (
724
+ torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents
725
+ )
726
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec)
727
+
728
+ # if parallel_len is small, no need to use multiple GPUs
729
+ net = self.wrapped_unet if parallel_len > 3 else self.unet
730
+ # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...]
731
+ model_output = net(
732
+ latent_model_input.flatten(0, 1),
733
+ t_vec.flatten(0, 1),
734
+ encoder_hidden_states=block_prompt_embeds.flatten(0, 1),
735
+ cross_attention_kwargs=cross_attention_kwargs,
736
+ return_dict=False,
737
+ )[0]
738
+
739
+ per_latent_shape = model_output.shape[1:]
740
+ if do_classifier_free_guidance:
741
+ model_output = model_output.reshape(
742
+ parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape
743
+ )
744
+ noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1]
745
+ model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
746
+ model_output = model_output.reshape(
747
+ parallel_len * batch_size * num_images_per_prompt, *per_latent_shape
748
+ )
749
+
750
+ block_latents_denoise = scheduler.batch_step_no_noise(
751
+ model_output=model_output,
752
+ timesteps=block_t.flatten(0, 1),
753
+ sample=block_latents.flatten(0, 1),
754
+ **extra_step_kwargs,
755
+ ).reshape(block_latents.shape)
756
+
757
+ # back to shape (parallel_dim, batch_size, ...)
758
+ # now we want to add the pre-sampled noise
759
+ # parallel sampling algorithm requires computing the cumulative drift from the beginning
760
+ # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises.
761
+ delta = block_latents_denoise - block_latents
762
+ cumulative_delta = self._cumsum(delta, dim=0, debug=debug)
763
+ cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug)
764
+
765
+ # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise
766
+ if scheduler._is_ode_scheduler:
767
+ cumulative_noise = 0
768
+
769
+ block_latents_new = (
770
+ latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise
771
+ )
772
+ cur_error = torch.linalg.norm(
773
+ (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape(
774
+ parallel_len, batch_size * num_images_per_prompt, -1
775
+ ),
776
+ dim=-1,
777
+ ).pow(2)
778
+ error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1]
779
+
780
+ # find the first index of the vector error_ratio that is greater than error tolerance
781
+ # we can shift the window for the next iteration up to this index
782
+ error_ratio = torch.nn.functional.pad(
783
+ error_ratio, (0, 0, 0, 1), value=1e9
784
+ ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension
785
+ any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int()
786
+ ind = torch.argmax(any_error_at_time).item()
787
+
788
+ # compute the new begin and end idxs for the window
789
+ new_begin_idx = begin_idx + min(1 + ind, parallel)
790
+ new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps))
791
+
792
+ # store the computed latents for the current window in the global buffer
793
+ latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new
794
+ # initialize the new sliding window latents with the end of the current window,
795
+ # should be better than random initialization
796
+ latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][
797
+ None,
798
+ ]
799
+
800
+ steps += 1
801
+
802
+ progress_bar.update(new_begin_idx - begin_idx)
803
+ if callback is not None and steps % callback_steps == 0:
804
+ callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx])
805
+
806
+ begin_idx = new_begin_idx
807
+ end_idx = new_end_idx
808
+
809
+ latents = latents_time_evolution_buffer[-1]
810
+
811
+ if not output_type == "latent":
812
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
813
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
814
+ else:
815
+ image = latents
816
+ has_nsfw_concept = None
817
+
818
+ if has_nsfw_concept is None:
819
+ do_denormalize = [True] * image.shape[0]
820
+ else:
821
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
822
+
823
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
824
+
825
+ # Offload last model to CPU
826
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
827
+ self.final_offload_hook.offload()
828
+
829
+ if not return_dict:
830
+ return (image, has_nsfw_concept)
831
+
832
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
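The denoising loop of the new `StableDiffusionParadigmsPipeline` slides a window of `parallel` timesteps forward based on how far the freshly computed latents drift from the previous iterate, relative to the scheduler's noise variance (the `error_ratio` vs. `scaled_tolerance = tolerance**2` check above). The snippet below is a minimal, self-contained sketch of just that window-advance rule; the `advance_window` helper and the toy tensor shapes are illustrative only and not part of the diffusers API.

```py
import torch


def advance_window(error_ratio, begin_idx, parallel, num_timesteps, scaled_tolerance):
    # `error_ratio` has shape (parallel_len, batch): squared drift of each window
    # step, already scaled by the inverse noise variance (as in the loop above).
    # Pad the end so argmax still returns an index when every step is within
    # tolerance; in that case the window slides forward by the full `parallel`.
    error_ratio = torch.nn.functional.pad(error_ratio, (0, 0, 0, 1), value=1e9)
    any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int()
    ind = torch.argmax(any_error_at_time).item()

    # slide the window start forward by `1 + ind` (at least one step per
    # iteration), capped at `parallel`, then re-extend the window end
    new_begin_idx = begin_idx + min(1 + ind, parallel)
    new_end_idx = min(new_begin_idx + parallel, num_timesteps)
    return new_begin_idx, new_end_idx


# toy example: a window of 4 steps where the first two are within tolerance
err = torch.tensor([[0.001], [0.005], [0.5], [0.9]])
print(advance_window(err, begin_idx=0, parallel=4, num_timesteps=1000, scaled_tolerance=0.1**2))
# -> (3, 7)
```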