diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,531 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from copy import deepcopy
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+
24
+ from ...models import UNet2DConditionModel, VQModel
25
+ from ...pipelines import DiffusionPipeline
26
+ from ...pipelines.pipeline_utils import ImagePipelineOutput
27
+ from ...schedulers import DDPMScheduler
28
+ from ...utils import (
29
+ is_accelerate_available,
30
+ is_accelerate_version,
31
+ logging,
32
+ randn_tensor,
33
+ replace_example_docstring,
34
+ )
35
+
36
+
37
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example for `KandinskyV22InpaintPipeline`. It is passed to `replace_example_docstring` on
# `KandinskyV22InpaintPipeline.__call__`, presumably to fill in that method's `Examples:` docstring section.
# In this pipeline a mask value of 1 keeps the original image and 0 marks the area to repaint, which is why the
# example starts from an all-ones mask and zeroes only the top-centre region.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline
        >>> from diffusers.utils import load_image
        >>> import torch
        >>> import numpy as np

        >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
        ... )
        >>> pipe_prior.to("cuda")

        >>> prompt = "a hat"
        >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

        >>> pipe = KandinskyV22InpaintPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> init_image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/cat.png"
        ... )

        >>> mask = np.ones((768, 768), dtype=np.float32)
        >>> mask[:250, 250:-250] = 0

        >>> out = pipe(
        ...     image=init_image,
        ...     mask_image=mask,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=50,
        ... )

        >>> image = out.images[0]
        >>> image.save("cat_with_hat.png")
        ```
"""
82
+
83
+
84
# Same computation as diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    """
    Map a pixel resolution to the latent resolution used by the pipeline.

    Each side is rounded *up* to the next multiple of ``scale_factor**2`` and then divided by ``scale_factor``, so the
    returned latent size is always a multiple of ``scale_factor``.

    Args:
        height (`int`): Requested image height in pixels.
        width (`int`): Requested image width in pixels.
        scale_factor (`int`, defaults to 8): Down-sampling factor between pixel space and latent space.

    Returns:
        `tuple`: ``(latent_height, latent_width)``.
    """
    granularity = scale_factor**2

    def _round_up(size):
        # ceil(size / granularity), computed with floor division on the negated value
        return -(-size // granularity)

    return _round_up(height) * scale_factor, _round_up(width) * scale_factor
93
+
94
+
95
# Adapted from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask
# (vectorized; produces exactly the same values as the original per-pixel loop)
def prepare_mask(masks):
    """
    Shrink the "keep" region (value ``1``) of each mask by zeroing the neighbours of every pixel that is not ``1``.

    For every position ``(i, j)`` where channel 0 of the *input* mask is not exactly ``1``, these neighbours are set to
    ``0`` in all channels: ``(i-1, j)``, ``(i, j-1)``, ``(i-1, j-1)``, ``(i+1, j)``, ``(i, j+1)`` and ``(i+1, j+1)``.
    The anti-diagonal neighbours ``(i-1, j+1)`` / ``(i+1, j-1)`` and the pixel itself are deliberately not written, to
    stay bit-compatible with the original loop implementation.

    Args:
        masks: An iterable of ``channels x height x width`` tensors, typically a
            ``batch x channels x height x width`` tensor. Each mask is modified in place.

    Returns:
        torch.Tensor: The processed masks stacked along a new leading dimension.
    """
    prepared_masks = []
    for mask in masks:
        # Decide from the unmodified values (the loop version used a deepcopy for this). `!=` returns a new
        # tensor, so the in-place fill below cannot affect it.
        source = mask[0] != 1
        kill = torch.zeros_like(source)
        # Each statement marks the neighbour at (source position + offset); slicing drops out-of-bounds targets.
        kill[:-1, :] |= source[1:, :]  # (i - 1, j)
        kill[1:, :] |= source[:-1, :]  # (i + 1, j)
        kill[:, :-1] |= source[:, 1:]  # (i, j - 1)
        kill[:, 1:] |= source[:, :-1]  # (i, j + 1)
        kill[:-1, :-1] |= source[1:, 1:]  # (i - 1, j - 1)
        kill[1:, 1:] |= source[:-1, :-1]  # (i + 1, j + 1)
        # The (H, W) boolean map broadcasts over the channel dimension.
        mask.masked_fill_(kill, 0)
        prepared_masks.append(mask)
    return torch.stack(prepared_masks, dim=0)
118
+
119
+
120
# Adapted from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image
# (does not mutate a caller-provided mask tensor; error-message and docstring fixes)
def prepare_mask_and_masked_image(image, mask, height, width):
    r"""
    Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32``; PIL / NumPy images are rescaled from ``[0, 255]`` to
    ``[-1, 1]``. The ``mask`` will be binarized: values ``< 0.5`` become ``0`` and values ``>= 0.5`` become ``1``.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` (binarized as-is, so expected in
            ``[0, 1]``) or a ``1 x height x width`` ``torch.Tensor`` or a ``batch x 1 x height x width``
            ``torch.Tensor``. A tensor mask passed in by the caller is left unmodified.
        height (`int`):
            Target height in pixels. Only used to resize PIL inputs.
        width (`int`):
            Target width in pixels. Only used to resize PIL inputs.

    Raises:
        ValueError: ``image`` or ``mask`` is ``None``; a ``torch.Tensor`` image is outside ``[-1, 1]``; a
            ``torch.Tensor`` mask is outside ``[0, 1]``.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
        AssertionError: tensor inputs have the wrong rank, mismatched spatial dimensions or mismatched batch sizes.
            These checks use ``assert`` and are therefore skipped under ``python -O``.

    Returns:
        tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
            dimensions: ``batch x channels x height x width``.
    """

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask. `unsqueeze` returns views of the caller's tensor, so copy first instead of
        # overwriting the user's `mask_image` in place.
        mask = mask.clone()
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t passed height and width
            image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # NHWC uint8-range -> NCHW float32 in [-1, 1]
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `np.concatenate` produced a fresh array above, so binarizing in place is safe here.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    return mask, image
235
+
236
+
237
class KandinskyV22InpaintPipeline(DiffusionPipeline):
    """
    Pipeline for text-guided image inpainting using Kandinsky 2.2

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # One factor of 2 per MoVQ block after the first; used to derive the latent resolution.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
        when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            image (`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`, or a list of PIL images / arrays):
                The image (or image batch) to inpaint: parts of it are masked out with `mask_image` and repainted,
                conditioned on `image_embeds`. PIL images are resized to `height` x `width`; tensors must be in
                `[-1, 1]`.
            mask_image (`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`):
                Mask selecting the region to repaint. After binarization, pixels with value 0 (black) are repainted
                while pixels with value 1 (white) are preserved. A PIL mask is converted to a single luminance
                channel. A tensor mask must have one channel, i.e. shape `(H, W)`, `(1, H, W)`, `(B, H, W)` or
                `(B, 1, H, W)`, with values in `[0, 1]`.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Validate up front so an unsupported output type does not cost a full denoising run.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        batch_size = image_embeds.shape[0] * num_images_per_prompt
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        # Expand per requested image with or without guidance, so the embedding batch always matches `batch_size`
        # (the latent batch) below.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Unconditional first, conditional second: matches `noise_pred.chunk(2)` in the loop.
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        # preprocess image and mask
        mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)

        image = image.to(dtype=image_embeds.dtype, device=device)
        image = self.movq.encode(image)["latents"]

        mask_image = mask_image.to(dtype=image_embeds.dtype, device=device)

        # Bring the mask to the spatial size of the encoded image latents.
        image_shape = tuple(image.shape[-2:])
        mask_image = F.interpolate(
            mask_image,
            image_shape,
            mode="nearest",
        )
        mask_image = prepare_mask(mask_image)
        masked_image = image * mask_image

        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
        masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            mask_image = mask_image.repeat(2, 1, 1, 1)
            masked_image = masked_image.repeat(2, 1, 1, 1)

        num_channels_latents = self.movq.config.latent_channels

        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )
        # Keep the initial noise so the preserved region can be re-noised to each step's noise level.
        noise = torch.clone(latents)
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # The inpainting UNet takes latents, masked image latents and mask stacked along channels.
            latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)

            added_cond_kwargs = {"image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Drop the predicted variance channels unless the scheduler consumes them.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]
            init_latents_proper = image[:1]
            init_mask = mask_image[:1]

            if i < len(timesteps_tensor) - 1:
                noise_timestep = timesteps_tensor[i + 1]
                init_latents_proper = self.scheduler.add_noise(
                    init_latents_proper, noise, torch.tensor([noise_timestep])
                )

            # Keep the known (mask == 1) region tied to the original image at the current noise level.
            latents = init_mask * init_latents_proper + (1 - init_mask) * latents
        # post-processing
        latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        # With `enable_model_cpu_offload`, movq is the last hooked model; move it back to CPU now.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)