diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
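The headline additions in this release are new pipelines (consistency models, Kandinsky 2.2, Shap-E, Stable Diffusion XL, LDM3D, paradigms) plus their schedulers. As a quick sanity check after upgrading, here is a minimal sketch (an illustration only, assuming a standard pip environment with `torch` and `transformers` already installed) that confirms a few of the newly added classes resolve from the top-level package:

```py
# Hedged sketch: upgrade to the new wheel and verify that classes added in this
# release (per the file list above) are importable from the top-level package.
# pip install --upgrade diffusers==0.18.2
import diffusers

print(diffusers.__version__)  # expect "0.18.2"

from diffusers import (
    ConsistencyModelPipeline,          # new consistency-models pipeline
    KandinskyV22PriorEmb2EmbPipeline,  # new Kandinsky 2.2 prior (emb2emb), diffed in full below
    ShapEPipeline,                     # new Shap-E pipeline
)
```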
diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py (new file)
@@ -0,0 +1,605 @@
+ from typing import List, Optional, Union
+
+ import PIL
+ import torch
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+ from ...models import PriorTransformer
+ from ...pipelines import DiffusionPipeline
+ from ...schedulers import UnCLIPScheduler
+ from ...utils import (
+     is_accelerate_available,
+     logging,
+     randn_tensor,
+     replace_example_docstring,
+ )
+ from ..kandinsky import KandinskyPriorPipelineOutput
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline
+         >>> from diffusers.utils import load_image
+         >>> import torch
+
+         >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
+         ...     "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+         ... )
+         >>> pipe_prior.to("cuda")
+
+         >>> prompt = "red cat, 4k photo"
+         >>> img = load_image(
+         ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+         ...     "/kandinsky/cat.png"
+         ... )
+         >>> image_emb, negative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple()
+
+         >>> pipe = KandinskyV22Pipeline.from_pretrained(
+         ...     "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+         ... )
+         >>> pipe.to("cuda")
+
+         >>> image = pipe(
+         ...     image_embeds=image_emb,
+         ...     negative_image_embeds=negative_image_emb,
+         ...     height=768,
+         ...     width=768,
+         ...     num_inference_steps=100,
+         ... ).images
+
+         >>> image[0].save("cat.png")
+         ```
+ """
+
+ EXAMPLE_INTERPOLATE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline
+         >>> from diffusers.utils import load_image
+         >>> import PIL
+
+         >>> import torch
+         >>> from torchvision import transforms
+
+         >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(
+         ...     "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+         ... )
+         >>> pipe_prior.to("cuda")
+
+         >>> img1 = load_image(
+         ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+         ...     "/kandinsky/cat.png"
+         ... )
+
+         >>> img2 = load_image(
+         ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+         ...     "/kandinsky/starry_night.jpeg"
+         ... )
+
+         >>> images_texts = ["a cat", img1, img2]
+         >>> weights = [0.3, 0.3, 0.4]
+         >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
+
+         >>> pipe = KandinskyV22Pipeline.from_pretrained(
+         ...     "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+         ... )
+         >>> pipe.to("cuda")
+
+         >>> image = pipe(
+         ...     image_embeds=image_emb,
+         ...     negative_image_embeds=zero_image_emb,
+         ...     height=768,
+         ...     width=768,
+         ...     num_inference_steps=150,
+         ... ).images[0]
+
+         >>> image.save("starry_cat.png")
+         ```
+ """
+
+
+ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
+     """
+     Pipeline for generating an image prior for Kandinsky.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         prior ([`PriorTransformer`]):
+             The canonical unCLIP prior to approximate the image embedding from the text embedding.
+         image_encoder ([`CLIPVisionModelWithProjection`]):
+             Frozen image-encoder.
+         text_encoder ([`CLIPTextModelWithProjection`]):
+             Frozen text-encoder.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         scheduler ([`UnCLIPScheduler`]):
+             A scheduler to be used in combination with `prior` to generate image embedding.
+         image_processor (`CLIPImageProcessor`):
+             An image processor used to preprocess images for the image encoder.
+     """
+
+     def __init__(
+         self,
+         prior: PriorTransformer,
+         image_encoder: CLIPVisionModelWithProjection,
+         text_encoder: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         scheduler: UnCLIPScheduler,
+         image_processor: CLIPImageProcessor,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             prior=prior,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             scheduler=scheduler,
+             image_encoder=image_encoder,
+             image_processor=image_processor,
+         )
+
+     def get_timesteps(self, num_inference_steps, strength, device):
+         # get the original timestep using init_timestep
+         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+         t_start = max(num_inference_steps - init_timestep, 0)
+         timesteps = self.scheduler.timesteps[t_start:]
+
+         return timesteps, num_inference_steps - t_start
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
+     def interpolate(
+         self,
+         images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
+         weights: List[float],
+         num_images_per_prompt: int = 1,
+         num_inference_steps: int = 25,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         negative_prior_prompt: Optional[str] = None,
+         negative_prompt: str = "",
+         guidance_scale: float = 4.0,
+         device=None,
+     ):
+         """
+         Function invoked when using the prior pipeline for interpolation.
+
+         Args:
+             images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
+                 List of prompts and images to guide the image generation.
+             weights (`List[float]`):
+                 List of weights for each condition in `images_and_prompts`.
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             num_inference_steps (`int`, *optional*, defaults to 25):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                 latents tensor will be generated by sampling using the supplied random `generator`.
+             negative_prior_prompt (`str`, *optional*):
+                 The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
+                 `guidance_scale` is less than `1`).
+             guidance_scale (`float`, *optional*, defaults to 4.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images closely linked to the text
+                 `prompt`, usually at the expense of lower image quality.
+
+         Examples:
+
+         Returns:
+             [`KandinskyPriorPipelineOutput`] or `tuple`
+         """
+
+         device = device or self.device
+
+         if len(images_and_prompts) != len(weights):
+             raise ValueError(
+                 f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of the same length"
+             )
+
+         image_embeddings = []
+         for cond, weight in zip(images_and_prompts, weights):
+             if isinstance(cond, str):
+                 image_emb = self(
+                     cond,
+                     num_inference_steps=num_inference_steps,
+                     num_images_per_prompt=num_images_per_prompt,
+                     generator=generator,
+                     latents=latents,
+                     negative_prompt=negative_prior_prompt,
+                     guidance_scale=guidance_scale,
+                 ).image_embeds.unsqueeze(0)
+
+             elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
+                 image_emb = self._encode_image(
+                     cond, device=device, num_images_per_prompt=num_images_per_prompt
+                 ).unsqueeze(0)
+
+             else:
+                 raise ValueError(
+                     f"`images_and_prompts` can only contain elements of type `str`, `PIL.Image.Image` or `torch.Tensor`, but got {type(cond)}"
+                 )
+
+             image_embeddings.append(image_emb * weight)
+
+         image_emb = torch.cat(image_embeddings).sum(dim=0)
+
+         return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb))
+
+     def _encode_image(
+         self,
+         image: Union[torch.Tensor, List[PIL.Image.Image]],
+         device,
+         num_images_per_prompt,
+     ):
+         if not isinstance(image, torch.Tensor):
+             image = self.image_processor(image, return_tensors="pt").pixel_values.to(
+                 dtype=self.image_encoder.dtype, device=device
+             )
+
+         image_emb = self.image_encoder(image)["image_embeds"]  # B, D
+         image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0)
+         image_emb = image_emb.to(device=device)
+
+         return image_emb
+
+     def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+         emb = emb.to(device=device, dtype=dtype)
+
+         batch_size = batch_size * num_images_per_prompt
+
+         init_latents = emb
+
+         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+             additional_image_per_prompt = batch_size // init_latents.shape[0]
+             init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+             raise ValueError(
+                 f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+             )
+         else:
+             init_latents = torch.cat([init_latents], dim=0)
+
+         shape = init_latents.shape
+         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+         # get latents
+         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+         latents = init_latents
+
+         return latents
+
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
+     def get_zero_embed(self, batch_size=1, device=None):
+         device = device or self.device
+         zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
+             device=device, dtype=self.image_encoder.dtype
+         )
+         zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
+         zero_image_emb = zero_image_emb.repeat(batch_size, 1)
+         return zero_image_emb
+
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload
+     def enable_sequential_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+         models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded on the
+         GPU only when their specific submodule has its `forward` method called.
+         """
+         if is_accelerate_available():
+             from accelerate import cpu_offload
+         else:
+             raise ImportError("Please install accelerate via `pip install accelerate`")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         models = [
+             self.image_encoder,
+             self.text_encoder,
+         ]
+         for cpu_offloaded_model in models:
+             if cpu_offloaded_model is not None:
+                 cpu_offload(cpu_offloaded_model, device)
+
+     @property
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
+             return self.device
+         for module in self.text_encoder.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device
+
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
+     def _encode_prompt(
+         self,
+         prompt,
+         device,
+         num_images_per_prompt,
+         do_classifier_free_guidance,
+         negative_prompt=None,
+     ):
+         batch_size = len(prompt) if isinstance(prompt, list) else 1
+         # get prompt text embeddings
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+         text_mask = text_inputs.attention_mask.bool().to(device)
+
+         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+             logger.warning(
+                 "The following part of your input was truncated because CLIP can only handle sequences up to"
+                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+             )
+             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+         text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+         prompt_embeds = text_encoder_output.text_embeds
+         text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+         text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+         if do_classifier_free_guidance:
+             uncond_tokens: List[str]
+             if negative_prompt is None:
+                 uncond_tokens = [""] * batch_size
+             elif type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif isinstance(negative_prompt, str):
+                 uncond_tokens = [negative_prompt]
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             else:
+                 uncond_tokens = negative_prompt
+
+             uncond_input = self.tokenizer(
+                 uncond_tokens,
+                 padding="max_length",
+                 max_length=self.tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+             negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+             negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+             uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+             seq_len = negative_prompt_embeds.shape[1]
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+             seq_len = uncond_text_encoder_hidden_states.shape[1]
+             uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+             uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+                 batch_size * num_images_per_prompt, seq_len, -1
+             )
+             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+             # done duplicates
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+             text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+             text_mask = torch.cat([uncond_text_mask, text_mask])
+
+         return prompt_embeds, text_encoder_hidden_states, text_mask
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]],
+         strength: float = 0.3,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: int = 1,
+         num_inference_steps: int = 25,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         guidance_scale: float = 4.0,
+         output_type: Optional[str] = "pt",  # pt only
+         return_dict: bool = True,
+     ):
+         """
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             strength (`float`, *optional*, defaults to 0.3):
+                 Conceptually, indicates how much to transform the reference image embedding. Must be between 0 and 1.
+                 `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                 number of denoising steps depends on the amount of noise initially added.
+             image (`torch.Tensor`, `PIL.Image.Image`, or a list of them):
+                 The image or image embedding to condition on. A 2-D tensor is treated as a precomputed image
+                 embedding; images are encoded with the CLIP image encoder.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
+                 ignored if `guidance_scale` is less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             num_inference_steps (`int`, *optional*, defaults to 25):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                 latents tensor will be generated by sampling using the supplied random `generator`.
+             guidance_scale (`float`, *optional*, defaults to 4.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images closely linked to the text
+                 `prompt`, usually at the expense of lower image quality.
+             output_type (`str`, *optional*, defaults to `"pt"`):
+                 The output format of the generated image embedding. Choose between: `"np"` (`np.array`) or `"pt"`
+                 (`torch.Tensor`).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.
+
+         Examples:
+
+         Returns:
+             [`KandinskyPriorPipelineOutput`] or `tuple`
+         """
+
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         elif not isinstance(prompt, list):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if isinstance(negative_prompt, str):
+             negative_prompt = [negative_prompt]
+         elif not isinstance(negative_prompt, list) and negative_prompt is not None:
+             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+         # if the negative prompt is defined we double the batch size to
+         # directly retrieve the negative prompt embedding
+         if negative_prompt is not None:
+             prompt = prompt + negative_prompt
+             negative_prompt = 2 * negative_prompt
+
+         device = self._execution_device
+
+         batch_size = len(prompt)
+         batch_size = batch_size * num_images_per_prompt
+
+         do_classifier_free_guidance = guidance_scale > 1.0
+         prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+         )
+
+         if not isinstance(image, list):
+             image = [image]
+
+         if isinstance(image[0], torch.Tensor):
+             image = torch.cat(image, dim=0)
+
+         if isinstance(image, torch.Tensor) and image.ndim == 2:
+             # allow user to pass image_embeds directly
+             image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0)
+         elif isinstance(image, torch.Tensor) and image.ndim != 4:
+             raise ValueError(
+                 f"If `image` is passed as a PyTorch tensor or a list of tensors, each tensor must have shape [batch_size, channels, height, width]; currently {image[0].unsqueeze(0).shape}"
+             )
+         else:
+             image_embeds = self._encode_image(image, device, num_images_per_prompt)
+
+         # prior
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+         latents = image_embeds
+         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+         latent_timestep = timesteps[:1].repeat(batch_size)
+         latents = self.prepare_latents(
+             latents,
+             latent_timestep,
+             batch_size // num_images_per_prompt,
+             num_images_per_prompt,
+             prompt_embeds.dtype,
+             device,
+             generator,
+         )
+
+         for i, t in enumerate(self.progress_bar(timesteps)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+             predicted_image_embedding = self.prior(
+                 latent_model_input,
+                 timestep=t,
+                 proj_embedding=prompt_embeds,
+                 encoder_hidden_states=text_encoder_hidden_states,
+                 attention_mask=text_mask,
+             ).predicted_image_embedding
+
+             if do_classifier_free_guidance:
+                 predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+                 predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
+                     predicted_image_embedding_text - predicted_image_embedding_uncond
+                 )
+
+             if i + 1 == timesteps.shape[0]:
+                 prev_timestep = None
+             else:
+                 prev_timestep = timesteps[i + 1]
+
+             latents = self.scheduler.step(
+                 predicted_image_embedding,
+                 timestep=t,
+                 sample=latents,
+                 generator=generator,
+                 prev_timestep=prev_timestep,
+             ).prev_sample
+
+         latents = self.prior.post_process_latents(latents)
+
+         image_embeddings = latents
+
+         # if a negative prompt has been defined, we split the image embedding in two
+         if negative_prompt is None:
+             zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+         else:
+             image_embeddings, zero_embeds = image_embeddings.chunk(2)
+
+         if output_type not in ["pt", "np"]:
+             raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")
+
+         if output_type == "np":
+             image_embeddings = image_embeddings.cpu().numpy()
+             zero_embeds = zero_embeds.cpu().numpy()
+
+         if not return_dict:
+             return (image_embeddings, zero_embeds)
+
+         return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
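The new `strength` argument is the main knob this emb2emb prior adds over the plain Kandinsky prior: via `get_timesteps` above, it decides how much of the scheduler's timestep schedule is actually run on the noised image embedding. A small standalone sketch of that arithmetic (an illustration mirroring the method, not part of the diff):

```py
# Standalone sketch of the strength -> timesteps arithmetic used by
# KandinskyV22PriorEmb2EmbPipeline.get_timesteps above.
def preview_timesteps(num_inference_steps: int, strength: float):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    # the pipeline then iterates over scheduler.timesteps[t_start:]
    return t_start, num_inference_steps - t_start

print(preview_timesteps(25, 0.3))  # (18, 7): only the last 7 of 25 steps run
print(preview_timesteps(25, 1.0))  # (0, 25): full schedule, like the plain prior
```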
diffusers/pipelines/pipeline_flax_utils.py
@@ -83,8 +83,8 @@ class FlaxImagePipelineOutput(BaseOutput):
 
      Args:
          images (`List[PIL.Image.Image]` or `np.ndarray`)
-             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+             List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+             num_channels)`.
      """
 
      images: Union[List[PIL.Image.Image], np.ndarray]