diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (238) hide show
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -16,17 +16,25 @@ import inspect
16
16
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
- from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
19
+ from transformers import (
20
+ CLIPImageProcessor,
21
+ CLIPTextModel,
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ )
20
26
 
21
- from ...image_processor import VaeImageProcessor
27
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
22
28
  from ...loaders import (
23
29
  FromSingleFileMixin,
30
+ IPAdapterMixin,
24
31
  StableDiffusionXLLoraLoaderMixin,
25
32
  TextualInversionLoaderMixin,
26
33
  )
27
- from ...models import AutoencoderKL, UNet2DConditionModel
34
+ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
28
35
  from ...models.attention_processor import (
29
36
  AttnProcessor2_0,
37
+ FusedAttnProcessor2_0,
30
38
  LoRAAttnProcessor2_0,
31
39
  LoRAXFormersAttnProcessor,
32
40
  XFormersAttnProcessor,
@@ -93,8 +101,57 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
93
101
  return noise_cfg
94
102
 
95
103
 
104
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
105
+ def retrieve_timesteps(
106
+ scheduler,
107
+ num_inference_steps: Optional[int] = None,
108
+ device: Optional[Union[str, torch.device]] = None,
109
+ timesteps: Optional[List[int]] = None,
110
+ **kwargs,
111
+ ):
112
+ """
113
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
114
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
115
+
116
+ Args:
117
+ scheduler (`SchedulerMixin`):
118
+ The scheduler to get timesteps from.
119
+ num_inference_steps (`int`):
120
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
121
+ `timesteps` must be `None`.
122
+ device (`str` or `torch.device`, *optional*):
123
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
124
+ timesteps (`List[int]`, *optional*):
125
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
126
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
127
+ must be `None`.
128
+
129
+ Returns:
130
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
131
+ second element is the number of inference steps.
132
+ """
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ else:
144
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
145
+ timesteps = scheduler.timesteps
146
+ return timesteps, num_inference_steps
147
+
148
+
96
149
  class StableDiffusionXLPipeline(
97
- DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
150
+ DiffusionPipeline,
151
+ FromSingleFileMixin,
152
+ StableDiffusionXLLoraLoaderMixin,
153
+ TextualInversionLoaderMixin,
154
+ IPAdapterMixin,
98
155
  ):
99
156
  r"""
100
157
  Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -102,12 +159,12 @@ class StableDiffusionXLPipeline(
102
159
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
103
160
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
104
161
 
105
- In addition the pipeline inherits the following loading methods:
106
- - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
107
- - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
108
-
109
- as well as the following saving methods:
110
- - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
162
+ The pipeline also inherits the following loading methods:
163
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
164
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
165
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
166
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
167
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
111
168
 
112
169
  Args:
113
170
  vae ([`AutoencoderKL`]):
@@ -140,8 +197,16 @@ class StableDiffusionXLPipeline(
140
197
  watermark output images. If not defined, it will default to True if the package is installed, otherwise no
141
198
  watermarker will be used.
142
199
  """
143
- model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
144
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
200
+
201
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
202
+ _optional_components = [
203
+ "tokenizer",
204
+ "tokenizer_2",
205
+ "text_encoder",
206
+ "text_encoder_2",
207
+ "image_encoder",
208
+ "feature_extractor",
209
+ ]
145
210
  _callback_tensor_inputs = [
146
211
  "latents",
147
212
  "prompt_embeds",
@@ -161,6 +226,8 @@ class StableDiffusionXLPipeline(
161
226
  tokenizer_2: CLIPTokenizer,
162
227
  unet: UNet2DConditionModel,
163
228
  scheduler: KarrasDiffusionSchedulers,
229
+ image_encoder: CLIPVisionModelWithProjection = None,
230
+ feature_extractor: CLIPImageProcessor = None,
164
231
  force_zeros_for_empty_prompt: bool = True,
165
232
  add_watermarker: Optional[bool] = None,
166
233
  ):
@@ -174,6 +241,8 @@ class StableDiffusionXLPipeline(
174
241
  tokenizer_2=tokenizer_2,
175
242
  unet=unet,
176
243
  scheduler=scheduler,
244
+ image_encoder=image_encoder,
245
+ feature_extractor=feature_extractor,
177
246
  )
178
247
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
179
248
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -455,6 +524,31 @@ class StableDiffusionXLPipeline(
455
524
 
456
525
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
457
526
 
527
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
528
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
529
+ dtype = next(self.image_encoder.parameters()).dtype
530
+
531
+ if not isinstance(image, torch.Tensor):
532
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
533
+
534
+ image = image.to(device=device, dtype=dtype)
535
+ if output_hidden_states:
536
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
537
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
538
+ uncond_image_enc_hidden_states = self.image_encoder(
539
+ torch.zeros_like(image), output_hidden_states=True
540
+ ).hidden_states[-2]
541
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
542
+ num_images_per_prompt, dim=0
543
+ )
544
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
545
+ else:
546
+ image_embeds = self.image_encoder(image).image_embeds
547
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
548
+ uncond_image_embeds = torch.zeros_like(image_embeds)
549
+
550
+ return image_embeds, uncond_image_embeds
551
+
458
552
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
459
553
  def prepare_extra_step_kwargs(self, generator, eta):
460
554
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -588,7 +682,6 @@ class StableDiffusionXLPipeline(
588
682
  add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
589
683
  return add_time_ids
590
684
 
591
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
592
685
  def upcast_vae(self):
593
686
  dtype = self.vae.dtype
594
687
  self.vae.to(dtype=torch.float32)
@@ -599,6 +692,7 @@ class StableDiffusionXLPipeline(
599
692
  XFormersAttnProcessor,
600
693
  LoRAXFormersAttnProcessor,
601
694
  LoRAAttnProcessor2_0,
695
+ FusedAttnProcessor2_0,
602
696
  ),
603
697
  )
604
698
  # if xformers or torch_2_0 is used attention block does not need
@@ -636,6 +730,65 @@ class StableDiffusionXLPipeline(
636
730
  """Disables the FreeU mechanism if enabled."""
637
731
  self.unet.disable_freeu()
638
732
 
733
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
734
+ """
735
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
736
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
737
+
738
+ <Tip warning={true}>
739
+
740
+ This API is 🧪 experimental.
741
+
742
+ </Tip>
743
+
744
+ Args:
745
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
746
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
747
+ """
748
+ self.fusing_unet = False
749
+ self.fusing_vae = False
750
+
751
+ if unet:
752
+ self.fusing_unet = True
753
+ self.unet.fuse_qkv_projections()
754
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
755
+
756
+ if vae:
757
+ if not isinstance(self.vae, AutoencoderKL):
758
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
759
+
760
+ self.fusing_vae = True
761
+ self.vae.fuse_qkv_projections()
762
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
763
+
764
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
765
+ """Disable QKV projection fusion if enabled.
766
+
767
+ <Tip warning={true}>
768
+
769
+ This API is 🧪 experimental.
770
+
771
+ </Tip>
772
+
773
+ Args:
774
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
775
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
776
+
777
+ """
778
+ if unet:
779
+ if not self.fusing_unet:
780
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
781
+ else:
782
+ self.unet.unfuse_qkv_projections()
783
+ self.fusing_unet = False
784
+
785
+ if vae:
786
+ if not self.fusing_vae:
787
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
788
+ else:
789
+ self.vae.unfuse_qkv_projections()
790
+ self.fusing_vae = False
791
+
639
792
  # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
640
793
  def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
641
794
  """
@@ -696,6 +849,10 @@ class StableDiffusionXLPipeline(
696
849
  def num_timesteps(self):
697
850
  return self._num_timesteps
698
851
 
852
+ @property
853
+ def interrupt(self):
854
+ return self._interrupt
855
+
699
856
  @torch.no_grad()
700
857
  @replace_example_docstring(EXAMPLE_DOC_STRING)
701
858
  def __call__(
@@ -705,6 +862,7 @@ class StableDiffusionXLPipeline(
705
862
  height: Optional[int] = None,
706
863
  width: Optional[int] = None,
707
864
  num_inference_steps: int = 50,
865
+ timesteps: List[int] = None,
708
866
  denoising_end: Optional[float] = None,
709
867
  guidance_scale: float = 5.0,
710
868
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -717,6 +875,7 @@ class StableDiffusionXLPipeline(
717
875
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
718
876
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
719
877
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
878
+ ip_adapter_image: Optional[PipelineImageInput] = None,
720
879
  output_type: Optional[str] = "pil",
721
880
  return_dict: bool = True,
722
881
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -755,6 +914,10 @@ class StableDiffusionXLPipeline(
755
914
  num_inference_steps (`int`, *optional*, defaults to 50):
756
915
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
757
916
  expense of slower inference.
917
+ timesteps (`List[int]`, *optional*):
918
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
919
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
920
+ passed will be used. Must be in descending order.
758
921
  denoising_end (`float`, *optional*):
759
922
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
760
923
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -801,6 +964,7 @@ class StableDiffusionXLPipeline(
801
964
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
802
965
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
803
966
  input argument.
967
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
804
968
  output_type (`str`, *optional*, defaults to `"pil"`):
805
969
  The output format of the generate image. Choose between
806
970
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -853,7 +1017,7 @@ class StableDiffusionXLPipeline(
853
1017
  callback_on_step_end_tensor_inputs (`List`, *optional*):
854
1018
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
855
1019
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
856
- `._callback_tensor_inputs` attribute of your pipeine class.
1020
+ `._callback_tensor_inputs` attribute of your pipeline class.
857
1021
 
858
1022
  Examples:
859
1023
 
@@ -907,6 +1071,7 @@ class StableDiffusionXLPipeline(
907
1071
  self._clip_skip = clip_skip
908
1072
  self._cross_attention_kwargs = cross_attention_kwargs
909
1073
  self._denoising_end = denoising_end
1074
+ self._interrupt = False
910
1075
 
911
1076
  # 2. Define call parameters
912
1077
  if prompt is not None and isinstance(prompt, str):
@@ -945,9 +1110,7 @@ class StableDiffusionXLPipeline(
945
1110
  )
946
1111
 
947
1112
  # 4. Prepare timesteps
948
- self.scheduler.set_timesteps(num_inference_steps, device=device)
949
-
950
- timesteps = self.scheduler.timesteps
1113
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
951
1114
 
952
1115
  # 5. Prepare latent variables
953
1116
  num_channels_latents = self.unet.config.in_channels
@@ -999,6 +1162,15 @@ class StableDiffusionXLPipeline(
999
1162
  add_text_embeds = add_text_embeds.to(device)
1000
1163
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1001
1164
 
1165
+ if ip_adapter_image is not None:
1166
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
1167
+ image_embeds, negative_image_embeds = self.encode_image(
1168
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
1169
+ )
1170
+ if self.do_classifier_free_guidance:
1171
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
1172
+ image_embeds = image_embeds.to(device)
1173
+
1002
1174
  # 8. Denoising loop
1003
1175
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1004
1176
 
@@ -1029,6 +1201,9 @@ class StableDiffusionXLPipeline(
1029
1201
  self._num_timesteps = len(timesteps)
1030
1202
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1031
1203
  for i, t in enumerate(timesteps):
1204
+ if self.interrupt:
1205
+ continue
1206
+
1032
1207
  # expand the latents if we are doing classifier free guidance
1033
1208
  latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1034
1209
 
@@ -1036,6 +1211,8 @@ class StableDiffusionXLPipeline(
1036
1211
 
1037
1212
  # predict the noise residual
1038
1213
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1214
+ if ip_adapter_image is not None:
1215
+ added_cond_kwargs["image_embeds"] = image_embeds
1039
1216
  noise_pred = self.unet(
1040
1217
  latent_model_input,
1041
1218
  t,