diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -169,8 +169,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
169
169
  num_images_per_prompt,
170
170
  do_classifier_free_guidance,
171
171
  negative_prompt=None,
172
- prompt_embeds: Optional[torch.FloatTensor] = None,
173
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
172
+ prompt_embeds: Optional[torch.Tensor] = None,
173
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
174
174
  lora_scale: Optional[float] = None,
175
175
  **kwargs,
176
176
  ):
@@ -202,8 +202,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
202
202
  num_images_per_prompt,
203
203
  do_classifier_free_guidance,
204
204
  negative_prompt=None,
205
- prompt_embeds: Optional[torch.FloatTensor] = None,
206
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
205
+ prompt_embeds: Optional[torch.Tensor] = None,
206
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
207
207
  lora_scale: Optional[float] = None,
208
208
  clip_skip: Optional[int] = None,
209
209
  ):
@@ -223,10 +223,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
223
223
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
224
224
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
225
225
  less than `1`).
226
- prompt_embeds (`torch.FloatTensor`, *optional*):
226
+ prompt_embeds (`torch.Tensor`, *optional*):
227
227
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
228
228
  provided, text embeddings will be generated from `prompt` input argument.
229
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
229
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
230
230
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
231
231
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
232
232
  argument.
@@ -401,6 +401,40 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
401
401
 
402
402
  return image_embeds, uncond_image_embeds
403
403
 
404
+ def prepare_ip_adapter_image_embeds(
405
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
406
+ ):
407
+ if ip_adapter_image_embeds is None:
408
+ if not isinstance(ip_adapter_image, list):
409
+ ip_adapter_image = [ip_adapter_image]
410
+
411
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
412
+ raise ValueError(
413
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
414
+ )
415
+
416
+ image_embeds = []
417
+ for single_ip_adapter_image, image_proj_layer in zip(
418
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
419
+ ):
420
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
421
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
422
+ single_ip_adapter_image, device, 1, output_hidden_state
423
+ )
424
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
425
+ single_negative_image_embeds = torch.stack(
426
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
427
+ )
428
+
429
+ if do_classifier_free_guidance:
430
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
431
+ single_image_embeds = single_image_embeds.to(device)
432
+
433
+ image_embeds.append(single_image_embeds)
434
+ else:
435
+ image_embeds = ip_adapter_image_embeds
436
+ return image_embeds
437
+
404
438
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
405
439
  def run_safety_checker(self, image, device, dtype):
406
440
  if self.safety_checker is None:
@@ -501,7 +535,12 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
501
535
 
502
536
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
503
537
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
504
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
538
+ shape = (
539
+ batch_size,
540
+ num_channels_latents,
541
+ int(height) // self.vae_scale_factor,
542
+ int(width) // self.vae_scale_factor,
543
+ )
505
544
  if isinstance(generator, list) and len(generator) != batch_size:
506
545
  raise ValueError(
507
546
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -531,13 +570,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
531
570
  num_images_per_prompt: Optional[int] = 1,
532
571
  eta: float = 0.0,
533
572
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
534
- latents: Optional[torch.FloatTensor] = None,
535
- prompt_embeds: Optional[torch.FloatTensor] = None,
536
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
573
+ latents: Optional[torch.Tensor] = None,
574
+ prompt_embeds: Optional[torch.Tensor] = None,
575
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
537
576
  ip_adapter_image: Optional[PipelineImageInput] = None,
577
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
538
578
  output_type: Optional[str] = "pil",
539
579
  return_dict: bool = True,
540
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
580
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
541
581
  callback_steps: Optional[int] = 1,
542
582
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
543
583
  clip_skip: Optional[int] = None,
@@ -571,18 +611,21 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
571
611
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
572
612
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
573
613
  generation deterministic.
574
- latents (`torch.FloatTensor`, *optional*):
614
+ latents (`torch.Tensor`, *optional*):
575
615
  Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
576
616
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
577
617
  tensor is generated by sampling using the supplied random `generator`.
578
- prompt_embeds (`torch.FloatTensor`, *optional*):
618
+ prompt_embeds (`torch.Tensor`, *optional*):
579
619
  Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
580
620
  provided, text embeddings are generated from the `prompt` input argument.
581
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
621
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
582
622
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
583
623
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
584
624
  ip_adapter_image: (`PipelineImageInput`, *optional*):
585
625
  Optional image input to work with IP Adapters.
626
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
627
+ Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
628
+ `ip_adapter_image` input argument.
586
629
  output_type (`str`, *optional*, defaults to `"pil"`):
587
630
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
588
631
  return_dict (`bool`, *optional*, defaults to `True`):
@@ -590,7 +633,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
590
633
  plain tuple.
591
634
  callback (`Callable`, *optional*):
592
635
  A function that calls every `callback_steps` steps during inference. The function is called with the
593
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
636
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
594
637
  callback_steps (`int`, *optional*, defaults to 1):
595
638
  The frequency at which the `callback` function is called. If not specified, the callback is called at
596
639
  every step.
@@ -632,17 +675,28 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
632
675
  # corresponds to doing no classifier free guidance.
633
676
  do_classifier_free_guidance = guidance_scale > 1.0
634
677
  # and `sag_scale` is` `s` of equation (16)
635
- # of the self-attentnion guidance paper: https://arxiv.org/pdf/2210.00939.pdf
678
+ # of the self-attention guidance paper: https://arxiv.org/pdf/2210.00939.pdf
636
679
  # `sag_scale = 0` means no self-attention guidance
637
680
  do_self_attention_guidance = sag_scale > 0.0
638
681
 
639
- if ip_adapter_image is not None:
640
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
641
- image_embeds, negative_image_embeds = self.encode_image(
642
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
682
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
683
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
684
+ ip_adapter_image,
685
+ ip_adapter_image_embeds,
686
+ device,
687
+ batch_size * num_images_per_prompt,
688
+ do_classifier_free_guidance,
643
689
  )
690
+
644
691
  if do_classifier_free_guidance:
645
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
692
+ image_embeds = []
693
+ negative_image_embeds = []
694
+ for tmp_image_embeds in ip_adapter_image_embeds:
695
+ single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2)
696
+ image_embeds.append(single_image_embeds)
697
+ negative_image_embeds.append(single_negative_image_embeds)
698
+ else:
699
+ image_embeds = ip_adapter_image_embeds
646
700
 
647
701
  # 3. Encode input prompt
648
702
  prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -667,7 +721,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
667
721
 
668
722
  if timesteps.dtype not in [torch.int16, torch.int32, torch.int64]:
669
723
  raise ValueError(
670
- f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinlgestepScheduler'."
724
+ f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler'."
671
725
  )
672
726
 
673
727
  # 5. Prepare latent variables
@@ -687,10 +741,21 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
687
741
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
688
742
 
689
743
  # 6.1 Add image embeds for IP-Adapter
690
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
691
- added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None
744
+ added_cond_kwargs = (
745
+ {"image_embeds": image_embeds}
746
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
747
+ else None
748
+ )
749
+
750
+ if do_classifier_free_guidance:
751
+ added_uncond_kwargs = (
752
+ {"image_embeds": negative_image_embeds}
753
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
754
+ else None
755
+ )
692
756
 
693
757
  # 7. Denoising loop
758
+ original_attn_proc = self.unet.attn_processors
694
759
  store_processor = CrossAttnStoreProcessor()
695
760
  self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
696
761
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -723,7 +788,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
723
788
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
724
789
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
725
790
 
726
- # perform self-attention guidance with the stored self-attentnion map
791
+ # perform self-attention guidance with the stored self-attention map
727
792
  if do_self_attention_guidance:
728
793
  # classifier-free guidance produces two chunks of attention map
729
794
  # and we only use unconditional one according to equation (25)
@@ -789,6 +854,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
789
854
  image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
790
855
 
791
856
  self.maybe_free_model_hooks()
857
+ # make sure to set the original attention processors back
858
+ self.unet.set_attn_processor(original_attn_proc)
792
859
 
793
860
  if not return_dict:
794
861
  return (image, has_nsfw_concept)
@@ -24,6 +24,7 @@ from transformers import (
24
24
  CLIPVisionModelWithProjection,
25
25
  )
26
26
 
27
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27
28
  from ...image_processor import PipelineImageInput, VaeImageProcessor
28
29
  from ...loaders import (
29
30
  FromSingleFileMixin,
@@ -107,6 +108,7 @@ def retrieve_timesteps(
107
108
  num_inference_steps: Optional[int] = None,
108
109
  device: Optional[Union[str, torch.device]] = None,
109
110
  timesteps: Optional[List[int]] = None,
111
+ sigmas: Optional[List[float]] = None,
110
112
  **kwargs,
111
113
  ):
112
114
  """
@@ -117,19 +119,23 @@ def retrieve_timesteps(
117
119
  scheduler (`SchedulerMixin`):
118
120
  The scheduler to get timesteps from.
119
121
  num_inference_steps (`int`):
120
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
121
- `timesteps` must be `None`.
122
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
123
+ must be `None`.
122
124
  device (`str` or `torch.device`, *optional*):
123
125
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
124
126
  timesteps (`List[int]`, *optional*):
125
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
126
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
127
- must be `None`.
127
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
128
+ `num_inference_steps` and `sigmas` must be `None`.
129
+ sigmas (`List[float]`, *optional*):
130
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
131
+ `num_inference_steps` and `timesteps` must be `None`.
128
132
 
129
133
  Returns:
130
134
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
131
135
  second element is the number of inference steps.
132
136
  """
137
+ if timesteps is not None and sigmas is not None:
138
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
139
  if timesteps is not None:
134
140
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
141
  if not accepts_timesteps:
@@ -140,6 +146,16 @@ def retrieve_timesteps(
140
146
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
147
  timesteps = scheduler.timesteps
142
148
  num_inference_steps = len(timesteps)
149
+ elif sigmas is not None:
150
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
151
+ if not accept_sigmas:
152
+ raise ValueError(
153
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
154
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
155
+ )
156
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
157
+ timesteps = scheduler.timesteps
158
+ num_inference_steps = len(timesteps)
143
159
  else:
144
160
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
145
161
  timesteps = scheduler.timesteps
@@ -267,10 +283,10 @@ class StableDiffusionXLPipeline(
267
283
  do_classifier_free_guidance: bool = True,
268
284
  negative_prompt: Optional[str] = None,
269
285
  negative_prompt_2: Optional[str] = None,
270
- prompt_embeds: Optional[torch.FloatTensor] = None,
271
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
272
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
273
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
286
+ prompt_embeds: Optional[torch.Tensor] = None,
287
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
288
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
289
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
274
290
  lora_scale: Optional[float] = None,
275
291
  clip_skip: Optional[int] = None,
276
292
  ):
@@ -296,17 +312,17 @@ class StableDiffusionXLPipeline(
296
312
  negative_prompt_2 (`str` or `List[str]`, *optional*):
297
313
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
298
314
  `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
299
- prompt_embeds (`torch.FloatTensor`, *optional*):
315
+ prompt_embeds (`torch.Tensor`, *optional*):
300
316
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
301
317
  provided, text embeddings will be generated from `prompt` input argument.
302
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
318
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
303
319
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
304
320
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
305
321
  argument.
306
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
322
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
307
323
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
308
324
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
309
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
325
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
310
326
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
311
327
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
312
328
  input argument.
@@ -685,7 +701,12 @@ class StableDiffusionXLPipeline(
685
701
 
686
702
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
687
703
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
688
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
704
+ shape = (
705
+ batch_size,
706
+ num_channels_latents,
707
+ int(height) // self.vae_scale_factor,
708
+ int(width) // self.vae_scale_factor,
709
+ )
689
710
  if isinstance(generator, list) and len(generator) != batch_size:
690
711
  raise ValueError(
691
712
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -740,20 +761,22 @@ class StableDiffusionXLPipeline(
740
761
  self.vae.decoder.mid_block.to(dtype)
741
762
 
742
763
  # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
743
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
764
+ def get_guidance_scale_embedding(
765
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
766
+ ) -> torch.Tensor:
744
767
  """
745
768
  See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
746
769
 
747
770
  Args:
748
- timesteps (`torch.Tensor`):
749
- generate embedding vectors at these timesteps
771
+ w (`torch.Tensor`):
772
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
750
773
  embedding_dim (`int`, *optional*, defaults to 512):
751
- dimension of the embeddings to generate
752
- dtype:
753
- data type of the generated embeddings
774
+ Dimension of the embeddings to generate.
775
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
776
+ Data type of the generated embeddings.
754
777
 
755
778
  Returns:
756
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
779
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
757
780
  """
758
781
  assert len(w.shape) == 1
759
782
  w = w * 1000.0
@@ -813,6 +836,7 @@ class StableDiffusionXLPipeline(
813
836
  width: Optional[int] = None,
814
837
  num_inference_steps: int = 50,
815
838
  timesteps: List[int] = None,
839
+ sigmas: List[float] = None,
816
840
  denoising_end: Optional[float] = None,
817
841
  guidance_scale: float = 5.0,
818
842
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -820,13 +844,13 @@ class StableDiffusionXLPipeline(
820
844
  num_images_per_prompt: Optional[int] = 1,
821
845
  eta: float = 0.0,
822
846
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
823
- latents: Optional[torch.FloatTensor] = None,
824
- prompt_embeds: Optional[torch.FloatTensor] = None,
825
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
826
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
827
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
847
+ latents: Optional[torch.Tensor] = None,
848
+ prompt_embeds: Optional[torch.Tensor] = None,
849
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
850
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
851
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
828
852
  ip_adapter_image: Optional[PipelineImageInput] = None,
829
- ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
853
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
830
854
  output_type: Optional[str] = "pil",
831
855
  return_dict: bool = True,
832
856
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -838,7 +862,9 @@ class StableDiffusionXLPipeline(
838
862
  negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
839
863
  negative_target_size: Optional[Tuple[int, int]] = None,
840
864
  clip_skip: Optional[int] = None,
841
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
865
+ callback_on_step_end: Optional[
866
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
867
+ ] = None,
842
868
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
843
869
  **kwargs,
844
870
  ):
@@ -869,6 +895,10 @@ class StableDiffusionXLPipeline(
869
895
  Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
870
896
  in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
871
897
  passed will be used. Must be in descending order.
898
+ sigmas (`List[float]`, *optional*):
899
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
900
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
901
+ will be used.
872
902
  denoising_end (`float`, *optional*):
873
903
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
874
904
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -897,30 +927,30 @@ class StableDiffusionXLPipeline(
897
927
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
898
928
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
899
929
  to make generation deterministic.
900
- latents (`torch.FloatTensor`, *optional*):
930
+ latents (`torch.Tensor`, *optional*):
901
931
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
902
932
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
903
933
  tensor will be generated by sampling using the supplied random `generator`.
904
- prompt_embeds (`torch.FloatTensor`, *optional*):
934
+ prompt_embeds (`torch.Tensor`, *optional*):
905
935
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
906
936
  provided, text embeddings will be generated from `prompt` input argument.
907
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
937
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
908
938
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
909
939
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
910
940
  argument.
911
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
941
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
912
942
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
913
943
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
914
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
944
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
915
945
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
916
946
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
917
947
  input argument.
918
948
  ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
919
- ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
920
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
921
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
922
- if `do_classifier_free_guidance` is set to `True`.
923
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
949
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
950
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
951
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
952
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
953
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
924
954
  output_type (`str`, *optional*, defaults to `"pil"`):
925
955
  The output format of the generated image. Choose between
926
956
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -965,11 +995,11 @@ class StableDiffusionXLPipeline(
965
995
  as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
966
996
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
967
997
  information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
968
- callback_on_step_end (`Callable`, *optional*):
969
- A function that calls at the end of each denoising steps during the inference. The function is called
970
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
971
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
972
- `callback_on_step_end_tensor_inputs`.
998
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
999
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1000
+ each denoising step during the inference, with the following arguments: `callback_on_step_end(self:
1001
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1002
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
973
1003
  callback_on_step_end_tensor_inputs (`List`, *optional*):
974
1004
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
975
1005
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -999,6 +1029,9 @@ class StableDiffusionXLPipeline(
999
1029
  "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1000
1030
  )
1001
1031
 
1032
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1033
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1034
+
1002
1035
  # 0. Default height and width to unet
1003
1036
  height = height or self.default_sample_size * self.vae_scale_factor
1004
1037
  width = width or self.default_sample_size * self.vae_scale_factor
@@ -1068,7 +1101,9 @@ class StableDiffusionXLPipeline(
1068
1101
  )
1069
1102
 
1070
1103
  # 4. Prepare timesteps
1071
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1104
+ timesteps, num_inference_steps = retrieve_timesteps(
1105
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1106
+ )
1072
1107
 
1073
1108
  # 5. Prepare latent variables
1074
1109
  num_channels_latents = self.unet.config.in_channels
@@ -1191,7 +1226,12 @@ class StableDiffusionXLPipeline(
1191
1226
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1192
1227
 
1193
1228
  # compute the previous noisy sample x_t -> x_t-1
1229
+ latents_dtype = latents.dtype
1194
1230
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1231
+ if latents.dtype != latents_dtype:
1232
+ if torch.backends.mps.is_available():
1233
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1234
+ latents = latents.to(latents_dtype)
1195
1235
 
1196
1236
  if callback_on_step_end is not None:
1197
1237
  callback_kwargs = {}
@@ -1226,6 +1266,10 @@ class StableDiffusionXLPipeline(
1226
1266
  if needs_upcasting:
1227
1267
  self.upcast_vae()
1228
1268
  latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1269
+ elif latents.dtype != self.vae.dtype:
1270
+ if torch.backends.mps.is_available():
1271
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1272
+ self.vae = self.vae.to(latents.dtype)
1229
1273
 
1230
1274
  # unscale/denormalize the latents
1231
1275
  # denormalize with the mean and std if available and not None