diffusers 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (174)
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +2 -1
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +30 -29
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py}
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
@@ -7,8 +7,10 @@ from ...loaders import LoraLoaderMixin
 from ...models import Kandinsky3UNet, VQModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     logging,
+    replace_example_docstring,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -16,6 +18,23 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> from diffusers import AutoPipelineForText2Image
+        >>> import torch
+
+        >>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
+
+        >>> generator = torch.Generator(device="cpu").manual_seed(0)
+        >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
+        ```
+
+"""
+
 
 def downscale_height_and_width(height, width, scale_factor=8):
     new_height = height // scale_factor**2
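
The `EXAMPLE_DOC_STRING` above is spliced into `__call__`'s docstring by the `replace_example_docstring` decorator imported in the previous hunk. A minimal sketch of that decorator pattern, assuming only standard Python (the real helper lives in `diffusers.utils` and may differ in details):

    # Illustrative sketch only; not the actual diffusers implementation.
    def replace_example_docstring(example_docstring):
        def decorator(fn):
            # Substitute the example block at the "Examples:" placeholder so the
            # rendered docstring of the wrapped function shows a runnable snippet.
            if fn.__doc__ and "Examples:" in fn.__doc__:
                fn.__doc__ = fn.__doc__.replace("Examples:", example_docstring)
            return fn
        return decorator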
@@ -29,6 +48,13 @@ def downscale_height_and_width(height, width, scale_factor=8):
 
 class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
     model_cpu_offload_seq = "text_encoder->unet->movq"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "negative_attention_mask",
+        "attention_mask",
+    ]
 
     def __init__(
         self,
@@ -50,7 +76,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
         else:
             raise ImportError("Please install accelerate via `pip install accelerate`")
 
-        for model in [self.text_encoder, self.unet]:
+        for model in [self.text_encoder, self.unet, self.movq]:
             if model is not None:
                 remove_hook_from_module(model, recurse=True)
 
@@ -77,12 +103,14 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         _cut_context=False,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
 
         Args:
-            prompt (`str` or `List[str]`, *optional*):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             device: (`torch.device`, *optional*):
                 torch device to place the resulting embeddings on
@@ -101,6 +129,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
         """
         if prompt is not None and negative_prompt is not None:
             if type(prompt) is not type(negative_prompt):
@@ -228,14 +260,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        attention_mask=None,
+        negative_attention_mask=None,
     ):
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -262,8 +301,42 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
                 f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                 f" {negative_prompt_embeds.shape}."
             )
+        if negative_prompt_embeds is not None and negative_attention_mask is None:
+            raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
+
+        if negative_prompt_embeds is not None and negative_attention_mask is not None:
+            if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
+                raise ValueError(
+                    "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
+                    f" {negative_attention_mask.shape}."
+                )
+
+        if prompt_embeds is not None and attention_mask is None:
+            raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
+
+        if prompt_embeds is not None and attention_mask is not None:
+            if prompt_embeds.shape[:2] != attention_mask.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
+                    f" {attention_mask.shape}."
+                )
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
 
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
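
The new validation above means embeddings can no longer be passed without their masks: `check_inputs` now requires `attention_mask` with `prompt_embeds` and `negative_attention_mask` with `negative_prompt_embeds`, with shapes agreeing on batch size and token length. A hedged usage sketch, assuming a CUDA device and a loaded `Kandinsky3Pipeline` named `pipe` (the prompt text is illustrative):

    import torch
    from diffusers import Kandinsky3Pipeline

    pipe = Kandinsky3Pipeline.from_pretrained(
        "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
    ).to("cuda")

    # encode_prompt now also returns the two masks; pass them back alongside
    # the embeddings, otherwise check_inputs raises a ValueError.
    prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = pipe.encode_prompt(
        "a watercolor fox", True, device="cuda"  # second argument: do_classifier_free_guidance
    )
    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        attention_mask=attention_mask,
        negative_attention_mask=negative_attention_mask,
    ).images[0]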
@@ -276,11 +349,14 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        negative_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         latents=None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         """
         Function invoked when calling the pipeline for generation.
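
With the signature change above, the old `callback`/`callback_steps` arguments survive only via `**kwargs` (deprecated later in this diff); `callback_on_step_end` receives the pipeline, the step index, the timestep, and a dict of the tensors named in `callback_on_step_end_tensor_inputs`, and may return a dict that overrides those tensors. A hedged migration sketch, assuming a loaded pipeline `pipe`:

    # Before (deprecated, slated for removal in 1.0.0):
    #   pipe(prompt, callback=lambda step, t, latents: print(step), callback_steps=1)

    # After: names in callback_on_step_end_tensor_inputs must be a subset of
    # Kandinsky3Pipeline._callback_tensor_inputs, or check_inputs raises.
    def on_step_end(pipeline, step, timestep, callback_kwargs):
        latents = callback_kwargs["latents"]
        print(f"step {step}: latent std {latents.std().item():.4f}")
        return {"latents": latents}  # returned entries are popped back into the loop

    image = pipe(
        "a watercolor fox",
        callback_on_step_end=on_step_end,
        callback_on_step_end_tensor_inputs=["latents"],
    ).images[0]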
@@ -289,7 +365,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            num_inference_steps (`int`, *optional*, defaults to 50):
+            num_inference_steps (`int`, *optional*, defaults to 25):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
             timesteps (`List[int]`, *optional*):
@@ -324,6 +400,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
+            negative_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -343,12 +423,53 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`
+
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         cut_context = True
         device = self._execution_device
 
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
+        self.check_inputs(
+            prompt,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
+            attention_mask,
+            negative_attention_mask,
+        )
+
+        self._guidance_scale = guidance_scale
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -357,24 +478,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
             prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             num_images_per_prompt=num_images_per_prompt,
             device=device,
             negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             _cut_context=cut_context,
+            attention_mask=attention_mask,
+            negative_attention_mask=negative_attention_mask,
         )
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
             attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
         # 4. Prepare timesteps
@@ -397,11 +515,11 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
             self.text_encoder_offload_hook.offload()
 
         # 7. Denoising loop
-        # TODO(Yiyi): Correct the following line and use correctly
-        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
 
                 # predict the noise residual
                 noise_pred = self.unet(
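
The warm-up computation that was previously commented out behind a TODO is now active. For the `DDPMScheduler` this pipeline uses, `scheduler.order` is 1 and `set_timesteps` yields exactly `num_inference_steps` timesteps, so with the default 25 steps the arithmetic gives `num_warmup_steps = 25 - 25 * 1 = 0` and the progress bar updates on every iteration; the guard only changes behavior for schedulers whose `set_timesteps` produces more timesteps than inference steps.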
@@ -412,7 +530,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
                     return_dict=False,
                 )[0]
 
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
                     noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
@@ -425,26 +543,45 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
                     latents,
                     generator=generator,
                 ).prev_sample
-                progress_bar.update()
-                if callback is not None and i % callback_steps == 0:
-                    step_idx = i // getattr(self.scheduler, "order", 1)
-                    callback(step_idx, t, latents)
 
-        # post-processing
-        image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    attention_mask = callback_outputs.pop("attention_mask", attention_mask)
+                    negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
 
-        if output_type not in ["pt", "np", "pil"]:
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # post-processing
+        if output_type not in ["pt", "np", "pil", "latent"]:
             raise ValueError(
-                f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
+                f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
             )
 
-        if output_type in ["np", "pil"]:
-            image = image * 0.5 + 0.5
-            image = image.clamp(0, 1)
-            image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+        if not output_type == "latent":
+            image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+
+            if output_type in ["np", "pil"]:
+                image = image * 0.5 + 0.5
+                image = image.clamp(0, 1)
+                image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+            if output_type == "pil":
+                image = self.numpy_to_pil(image)
+        else:
+            image = latents
 
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)