diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
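
Beyond updates to existing modules, most of the new files are entirely new pipelines (consistency models, Kandinsky 2.2, Shap-E, Stable Diffusion XL, LDM3D, paradigms, text-to-video img2img) plus the schedulers and dummy-object modules that back them; the single 541-line hunk shown below appears to correspond to the new `diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py` file from that list. As a rough sketch of what the upgrade exposes (class names inferred from the new module names above — verify against the installed 0.18.x release):

```py
# Sketch only: public class names are assumed to follow the new modules listed above.
from diffusers import (
    ConsistencyModelPipeline,    # pipelines/consistency_models
    KandinskyV22Pipeline,        # pipelines/kandinsky2_2 (decoder)
    KandinskyV22PriorPipeline,   # pipelines/kandinsky2_2 (prior, shown in the hunk below)
    ShapEPipeline,               # pipelines/shap_e
    StableDiffusionXLPipeline,   # pipelines/stable_diffusion_xl
)

# Per the new dummy-object module, StableDiffusionXLPipeline additionally expects the
# invisible-watermark dependency at runtime; the import itself should resolve either way.
print(ConsistencyModelPipeline.__name__, StableDiffusionXLPipeline.__name__)
```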
@@ -0,0 +1,541 @@
+ from typing import List, Optional, Union
+
+ import PIL
+ import torch
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
+
+ from ...models import PriorTransformer
+ from ...pipelines import DiffusionPipeline
+ from ...schedulers import UnCLIPScheduler
+ from ...utils import (
+     is_accelerate_available,
+     logging,
+     randn_tensor,
+     replace_example_docstring,
+ )
+ from ..kandinsky import KandinskyPriorPipelineOutput
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ EXAMPLE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline
+         >>> import torch
+
+         >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior")
+         >>> pipe_prior.to("cuda")
+         >>> prompt = "red cat, 4k photo"
+         >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple()
+
+         >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder")
+         >>> pipe.to("cuda")
+         >>> image = pipe(
+         ...     image_embeds=image_emb,
+         ...     negative_image_embeds=negative_image_emb,
+         ...     height=768,
+         ...     width=768,
+         ...     num_inference_steps=50,
+         ... ).images
+         >>> image[0].save("cat.png")
+         ```
+ """
+
+ EXAMPLE_INTERPOLATE_DOC_STRING = """
+     Examples:
+         ```py
+         >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
+         >>> from diffusers.utils import load_image
+         >>> import PIL
+         >>> import torch
+         >>> from torchvision import transforms
+
+         >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
+         ...     "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
+         ... )
+         >>> pipe_prior.to("cuda")
+         >>> img1 = load_image(
+         ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+         ...     "/kandinsky/cat.png"
+         ... )
+         >>> img2 = load_image(
+         ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+         ...     "/kandinsky/starry_night.jpeg"
+         ... )
+         >>> images_texts = ["a cat", img1, img2]
+         >>> weights = [0.3, 0.3, 0.4]
+         >>> out = pipe_prior.interpolate(images_texts, weights)
+         >>> pipe = KandinskyV22Pipeline.from_pretrained(
+         ...     "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+         ... )
+         >>> pipe.to("cuda")
+         >>> image = pipe(
+         ...     image_embeds=out.image_embeds,
+         ...     negative_image_embeds=out.negative_image_embeds,
+         ...     height=768,
+         ...     width=768,
+         ...     num_inference_steps=50,
+         ... ).images[0]
+         >>> image.save("starry_cat.png")
+         ```
+ """
+
+
+ class KandinskyV22PriorPipeline(DiffusionPipeline):
+     """
+     Pipeline for generating the image prior for Kandinsky.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         prior ([`PriorTransformer`]):
+             The canonical unCLIP prior to approximate the image embedding from the text embedding.
+         image_encoder ([`CLIPVisionModelWithProjection`]):
+             Frozen image-encoder.
+         text_encoder ([`CLIPTextModelWithProjection`]):
+             Frozen text-encoder.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         scheduler ([`UnCLIPScheduler`]):
+             A scheduler to be used in combination with `prior` to generate the image embedding.
+         image_processor ([`CLIPImageProcessor`]):
+             An image processor to be used to preprocess images for the CLIP image encoder.
+     """
+
+     def __init__(
+         self,
+         prior: PriorTransformer,
+         image_encoder: CLIPVisionModelWithProjection,
+         text_encoder: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         scheduler: UnCLIPScheduler,
+         image_processor: CLIPImageProcessor,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             prior=prior,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             scheduler=scheduler,
+             image_encoder=image_encoder,
+             image_processor=image_processor,
+         )
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
+     def interpolate(
+         self,
+         images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
+         weights: List[float],
+         num_images_per_prompt: int = 1,
+         num_inference_steps: int = 25,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         negative_prior_prompt: Optional[str] = None,
+         negative_prompt: Union[str] = "",
+         guidance_scale: float = 4.0,
+         device=None,
+     ):
+         """
+         Function invoked when using the prior pipeline for interpolation.
+
+         Args:
+             images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
+                 list of prompts and images to guide the image generation.
+             weights (`List[float]`):
+                 list of weights for each condition in `images_and_prompts`
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             num_inference_steps (`int`, *optional*, defaults to 25):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             negative_prior_prompt (`str`, *optional*):
+                 The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
+                 `guidance_scale` is less than `1`).
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
+                 `guidance_scale` is less than `1`).
+             guidance_scale (`float`, *optional*, defaults to 4.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                 text `prompt`, usually at the expense of lower image quality.
+
+         Examples:
+
+         Returns:
+             [`KandinskyPriorPipelineOutput`] or `tuple`
+         """
+
+         device = device or self.device
+
+         if len(images_and_prompts) != len(weights):
+             raise ValueError(
+                 f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of the same length"
+             )
+
+         image_embeddings = []
+         for cond, weight in zip(images_and_prompts, weights):
+             if isinstance(cond, str):
+                 image_emb = self(
+                     cond,
+                     num_inference_steps=num_inference_steps,
+                     num_images_per_prompt=num_images_per_prompt,
+                     generator=generator,
+                     latents=latents,
+                     negative_prompt=negative_prior_prompt,
+                     guidance_scale=guidance_scale,
+                 ).image_embeds.unsqueeze(0)
+
+             elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
+                 if isinstance(cond, PIL.Image.Image):
+                     cond = (
+                         self.image_processor(cond, return_tensors="pt")
+                         .pixel_values[0]
+                         .unsqueeze(0)
+                         .to(dtype=self.image_encoder.dtype, device=device)
+                     )
+
+                 image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0)
+
+             else:
+                 raise ValueError(
+                     f"`images_and_prompts` can only contain elements of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
+                 )
+
+             image_embeddings.append(image_emb * weight)
+
+         image_emb = torch.cat(image_embeddings).sum(dim=0)
+
+         out_zero = self(
+             negative_prompt,
+             num_inference_steps=num_inference_steps,
+             num_images_per_prompt=num_images_per_prompt,
+             generator=generator,
+             latents=latents,
+             negative_prompt=negative_prior_prompt,
+             guidance_scale=guidance_scale,
+         )
+         zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds
+
+         return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
+
+     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+         if latents is None:
+             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         else:
+             if latents.shape != shape:
+                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+             latents = latents.to(device)
+
+         latents = latents * scheduler.init_noise_sigma
+         return latents
+
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed
+     def get_zero_embed(self, batch_size=1, device=None):
+         device = device or self.device
+         zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
+             device=device, dtype=self.image_encoder.dtype
+         )
+         zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
+         zero_image_emb = zero_image_emb.repeat(batch_size, 1)
+         return zero_image_emb
+
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload
+     def enable_sequential_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+         models have their state dicts saved to CPU, are then moved to `torch.device('meta')`, and are loaded onto the
+         GPU only when their specific submodule has its `forward` method called.
+         """
+         if is_accelerate_available():
+             from accelerate import cpu_offload
+         else:
+             raise ImportError("Please install accelerate via `pip install accelerate`")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         models = [
+             self.image_encoder,
+             self.text_encoder,
+         ]
+         for cpu_offloaded_model in models:
+             if cpu_offloaded_model is not None:
+                 cpu_offload(cpu_offloaded_model, device)
+
+     @property
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
+             return self.device
+         for module in self.text_encoder.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device
+
+     # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt
+     def _encode_prompt(
+         self,
+         prompt,
+         device,
+         num_images_per_prompt,
+         do_classifier_free_guidance,
+         negative_prompt=None,
+     ):
+         batch_size = len(prompt) if isinstance(prompt, list) else 1
+         # get prompt text embeddings
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+         text_mask = text_inputs.attention_mask.bool().to(device)
+
+         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+             logger.warning(
+                 "The following part of your input was truncated because CLIP can only handle sequences up to"
+                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+             )
+             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+         text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+         prompt_embeds = text_encoder_output.text_embeds
+         text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+         prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+         text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+         text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+         if do_classifier_free_guidance:
+             uncond_tokens: List[str]
+             if negative_prompt is None:
+                 uncond_tokens = [""] * batch_size
+             elif type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif isinstance(negative_prompt, str):
+                 uncond_tokens = [negative_prompt]
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+             else:
+                 uncond_tokens = negative_prompt
+
+             uncond_input = self.tokenizer(
+                 uncond_tokens,
+                 padding="max_length",
+                 max_length=self.tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+             negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+             negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+             uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+             seq_len = negative_prompt_embeds.shape[1]
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
+
+             seq_len = uncond_text_encoder_hidden_states.shape[1]
+             uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+             uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+                 batch_size * num_images_per_prompt, seq_len, -1
+             )
+             uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+             # done duplicates
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+             text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+             text_mask = torch.cat([uncond_text_mask, text_mask])
+
+         return prompt_embeds, text_encoder_hidden_states, text_mask
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: int = 1,
+         num_inference_steps: int = 25,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         guidance_scale: float = 4.0,
+         output_type: Optional[str] = "pt",  # pt only
+         return_dict: bool = True,
+     ):
+         """
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             num_inference_steps (`int`, *optional*, defaults to 25):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             guidance_scale (`float`, *optional*, defaults to 4.0):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                 text `prompt`, usually at the expense of lower image quality.
+             output_type (`str`, *optional*, defaults to `"pt"`):
+                 The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
+                 (`torch.Tensor`).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+         Examples:
+
+         Returns:
+             [`KandinskyPriorPipelineOutput`] or `tuple`
+         """
+
+         if isinstance(prompt, str):
+             prompt = [prompt]
+         elif not isinstance(prompt, list):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if isinstance(negative_prompt, str):
+             negative_prompt = [negative_prompt]
+         elif not isinstance(negative_prompt, list) and negative_prompt is not None:
+             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+         # if the negative prompt is defined we double the batch size to
+         # directly retrieve the negative prompt embedding
+         if negative_prompt is not None:
+             prompt = prompt + negative_prompt
+             negative_prompt = 2 * negative_prompt
+
+         device = self._execution_device
+
+         batch_size = len(prompt)
+         batch_size = batch_size * num_images_per_prompt
+
+         do_classifier_free_guidance = guidance_scale > 1.0
+         prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+         )
+
+         # prior
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         prior_timesteps_tensor = self.scheduler.timesteps
+
+         embedding_dim = self.prior.config.embedding_dim
+
+         latents = self.prepare_latents(
+             (batch_size, embedding_dim),
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+             self.scheduler,
+         )
+
+         for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+             predicted_image_embedding = self.prior(
+                 latent_model_input,
+                 timestep=t,
+                 proj_embedding=prompt_embeds,
+                 encoder_hidden_states=text_encoder_hidden_states,
+                 attention_mask=text_mask,
+             ).predicted_image_embedding
+
+             if do_classifier_free_guidance:
+                 predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+                 predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
+                     predicted_image_embedding_text - predicted_image_embedding_uncond
+                 )
+
+             if i + 1 == prior_timesteps_tensor.shape[0]:
+                 prev_timestep = None
+             else:
+                 prev_timestep = prior_timesteps_tensor[i + 1]
+
+             latents = self.scheduler.step(
+                 predicted_image_embedding,
+                 timestep=t,
+                 sample=latents,
+                 generator=generator,
+                 prev_timestep=prev_timestep,
+             ).prev_sample
+
+         latents = self.prior.post_process_latents(latents)
+
+         image_embeddings = latents
+
+         # if a negative prompt has been defined, we split the image embedding into two
+         if negative_prompt is None:
+             zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+         else:
+             image_embeddings, zero_embeds = image_embeddings.chunk(2)
+
+         if output_type not in ["pt", "np"]:
+             raise ValueError(f"Only the output types `pt` and `np` are supported, not output_type={output_type}")
+
+         if output_type == "np":
+             image_embeddings = image_embeddings.cpu().numpy()
+             zero_embeds = zero_embeds.cpu().numpy()
+
+         if not return_dict:
+             return (image_embeddings, zero_embeds)
+
+         return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
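
The denoising loop above uses the standard classifier-free guidance recombination: the latent batch is doubled (unconditional and text-conditioned halves), and the prior's two predictions are merged as `uncond + guidance_scale * (text - uncond)`. A minimal, self-contained sketch of just that arithmetic, using dummy tensors rather than the pipeline's actual modules:

```py
import torch

# Toy stand-ins: a doubled batch of predicted image embeddings,
# first half unconditional, second half text-conditioned.
batch_size, embedding_dim, guidance_scale = 2, 4, 4.0
predicted = torch.randn(2 * batch_size, embedding_dim)

uncond, text = predicted.chunk(2)
guided = uncond + guidance_scale * (text - uncond)  # guidance_scale == 1.0 recovers the conditional prediction

assert guided.shape == (batch_size, embedding_dim)
```

With `guidance_scale > 1` the update pushes the embedding further along the direction from the unconditional to the text-conditioned prediction, which is what the default `guidance_scale=4.0` in `__call__` does.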