diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445) hide show
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -26,6 +26,7 @@ from transformers import (
26
26
  CLIPVisionModelWithProjection,
27
27
  )
28
28
 
29
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
29
30
  from ...image_processor import PipelineImageInput, VaeImageProcessor
30
31
  from ...loaders import (
31
32
  FromSingleFileMixin,
@@ -36,8 +37,6 @@ from ...loaders import (
36
37
  from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
37
38
  from ...models.attention_processor import (
38
39
  AttnProcessor2_0,
39
- LoRAAttnProcessor2_0,
40
- LoRAXFormersAttnProcessor,
41
40
  XFormersAttnProcessor,
42
41
  )
43
42
  from ...models.lora import adjust_lora_scale_text_encoder
@@ -102,9 +101,21 @@ EXAMPLE_DOC_STRING = """
102
101
 
103
102
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
104
103
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
105
- """
106
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
107
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
104
+ r"""
105
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
106
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
107
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
108
+
109
+ Args:
110
+ noise_cfg (`torch.Tensor`):
111
+ The predicted noise tensor for the guided diffusion process.
112
+ noise_pred_text (`torch.Tensor`):
113
+ The predicted noise tensor for the text-guided diffusion process.
114
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
115
+ A rescale factor applied to the noise predictions.
116
+
117
+ Returns:
118
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
108
119
  """
109
120
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
110
121
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -131,124 +142,6 @@ def mask_pil_to_torch(mask, height, width):
131
142
  return mask
132
143
 
133
144
 
134
- def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
135
- """
136
- Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
137
- converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
138
- ``image`` and ``1`` for the ``mask``.
139
-
140
- The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
141
- binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
142
-
143
- Args:
144
- image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
145
- It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
146
- ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
147
- mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
148
- It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
149
- ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
150
-
151
-
152
- Raises:
153
- ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
154
- should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
155
- TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
156
- (ot the other way around).
157
-
158
- Returns:
159
- tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
160
- dimensions: ``batch x channels x height x width``.
161
- """
162
-
163
- # checkpoint. TOD(Yiyi) - need to clean this up later
164
- deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
165
- deprecate(
166
- "prepare_mask_and_masked_image",
167
- "0.30.0",
168
- deprecation_message,
169
- )
170
- if image is None:
171
- raise ValueError("`image` input cannot be undefined.")
172
-
173
- if mask is None:
174
- raise ValueError("`mask_image` input cannot be undefined.")
175
-
176
- if isinstance(image, torch.Tensor):
177
- if not isinstance(mask, torch.Tensor):
178
- mask = mask_pil_to_torch(mask, height, width)
179
-
180
- if image.ndim == 3:
181
- image = image.unsqueeze(0)
182
-
183
- # Batch and add channel dim for single mask
184
- if mask.ndim == 2:
185
- mask = mask.unsqueeze(0).unsqueeze(0)
186
-
187
- # Batch single mask or add channel dim
188
- if mask.ndim == 3:
189
- # Single batched mask, no channel dim or single mask not batched but channel dim
190
- if mask.shape[0] == 1:
191
- mask = mask.unsqueeze(0)
192
-
193
- # Batched masks no channel dim
194
- else:
195
- mask = mask.unsqueeze(1)
196
-
197
- assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
198
- # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
199
- assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
200
-
201
- # Check image is in [-1, 1]
202
- # if image.min() < -1 or image.max() > 1:
203
- # raise ValueError("Image should be in [-1, 1] range")
204
-
205
- # Check mask is in [0, 1]
206
- if mask.min() < 0 or mask.max() > 1:
207
- raise ValueError("Mask should be in [0, 1] range")
208
-
209
- # Binarize mask
210
- mask[mask < 0.5] = 0
211
- mask[mask >= 0.5] = 1
212
-
213
- # Image as float32
214
- image = image.to(dtype=torch.float32)
215
- elif isinstance(mask, torch.Tensor):
216
- raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
217
- else:
218
- # preprocess image
219
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
220
- image = [image]
221
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
222
- # resize all images w.r.t passed height an width
223
- image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
224
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
225
- image = np.concatenate(image, axis=0)
226
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
227
- image = np.concatenate([i[None, :] for i in image], axis=0)
228
-
229
- image = image.transpose(0, 3, 1, 2)
230
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
231
-
232
- mask = mask_pil_to_torch(mask, height, width)
233
- mask[mask < 0.5] = 0
234
- mask[mask >= 0.5] = 1
235
-
236
- if image.shape[1] == 4:
237
- # images are in latent space and thus can't
238
- # be masked set masked_image to None
239
- # we assume that the checkpoint is not an inpainting
240
- # checkpoint. TOD(Yiyi) - need to clean this up later
241
- masked_image = None
242
- else:
243
- masked_image = image * (mask < 0.5)
244
-
245
- # n.b. ensure backwards compatibility as old function does not return image
246
- if return_image:
247
- return mask, masked_image, image
248
-
249
- return mask, masked_image
250
-
251
-
252
145
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
253
146
  def retrieve_latents(
254
147
  encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -269,9 +162,10 @@ def retrieve_timesteps(
269
162
  num_inference_steps: Optional[int] = None,
270
163
  device: Optional[Union[str, torch.device]] = None,
271
164
  timesteps: Optional[List[int]] = None,
165
+ sigmas: Optional[List[float]] = None,
272
166
  **kwargs,
273
167
  ):
274
- """
168
+ r"""
275
169
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
276
170
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
277
171
 
@@ -279,19 +173,23 @@ def retrieve_timesteps(
279
173
  scheduler (`SchedulerMixin`):
280
174
  The scheduler to get timesteps from.
281
175
  num_inference_steps (`int`):
282
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
283
- `timesteps` must be `None`.
176
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
177
+ must be `None`.
284
178
  device (`str` or `torch.device`, *optional*):
285
179
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
286
180
  timesteps (`List[int]`, *optional*):
287
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
288
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
289
- must be `None`.
181
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
182
+ `num_inference_steps` and `sigmas` must be `None`.
183
+ sigmas (`List[float]`, *optional*):
184
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
185
+ `num_inference_steps` and `timesteps` must be `None`.
290
186
 
291
187
  Returns:
292
188
  `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
293
189
  second element is the number of inference steps.
294
190
  """
191
+ if timesteps is not None and sigmas is not None:
192
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
295
193
  if timesteps is not None:
296
194
  accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
297
195
  if not accepts_timesteps:
@@ -302,6 +200,16 @@ def retrieve_timesteps(
302
200
  scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
303
201
  timesteps = scheduler.timesteps
304
202
  num_inference_steps = len(timesteps)
203
+ elif sigmas is not None:
204
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
205
+ if not accept_sigmas:
206
+ raise ValueError(
207
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
208
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
209
+ )
210
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
211
+ timesteps = scheduler.timesteps
212
+ num_inference_steps = len(timesteps)
305
213
  else:
306
214
  scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
307
215
  timesteps = scheduler.timesteps
@@ -377,11 +285,8 @@ class StableDiffusionXLInpaintPipeline(
377
285
  _callback_tensor_inputs = [
378
286
  "latents",
379
287
  "prompt_embeds",
380
- "negative_prompt_embeds",
381
288
  "add_text_embeds",
382
289
  "add_time_ids",
383
- "negative_pooled_prompt_embeds",
384
- "add_neg_time_ids",
385
290
  "mask",
386
291
  "masked_image_latents",
387
292
  ]
@@ -458,6 +363,9 @@ class StableDiffusionXLInpaintPipeline(
458
363
  def prepare_ip_adapter_image_embeds(
459
364
  self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
460
365
  ):
366
+ image_embeds = []
367
+ if do_classifier_free_guidance:
368
+ negative_image_embeds = []
461
369
  if ip_adapter_image_embeds is None:
462
370
  if not isinstance(ip_adapter_image, list):
463
371
  ip_adapter_image = [ip_adapter_image]
@@ -467,7 +375,6 @@ class StableDiffusionXLInpaintPipeline(
467
375
  f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
468
376
  )
469
377
 
470
- image_embeds = []
471
378
  for single_ip_adapter_image, image_proj_layer in zip(
472
379
  ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
473
380
  ):
@@ -475,36 +382,28 @@ class StableDiffusionXLInpaintPipeline(
475
382
  single_image_embeds, single_negative_image_embeds = self.encode_image(
476
383
  single_ip_adapter_image, device, 1, output_hidden_state
477
384
  )
478
- single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
479
- single_negative_image_embeds = torch.stack(
480
- [single_negative_image_embeds] * num_images_per_prompt, dim=0
481
- )
482
385
 
386
+ image_embeds.append(single_image_embeds[None, :])
483
387
  if do_classifier_free_guidance:
484
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
485
- single_image_embeds = single_image_embeds.to(device)
486
-
487
- image_embeds.append(single_image_embeds)
388
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
488
389
  else:
489
- repeat_dims = [1]
490
- image_embeds = []
491
390
  for single_image_embeds in ip_adapter_image_embeds:
492
391
  if do_classifier_free_guidance:
493
392
  single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
494
- single_image_embeds = single_image_embeds.repeat(
495
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
496
- )
497
- single_negative_image_embeds = single_negative_image_embeds.repeat(
498
- num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
499
- )
500
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
501
- else:
502
- single_image_embeds = single_image_embeds.repeat(
503
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
504
- )
393
+ negative_image_embeds.append(single_negative_image_embeds)
505
394
  image_embeds.append(single_image_embeds)
506
395
 
507
- return image_embeds
396
+ ip_adapter_image_embeds = []
397
+ for i, single_image_embeds in enumerate(image_embeds):
398
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
399
+ if do_classifier_free_guidance:
400
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
401
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
402
+
403
+ single_image_embeds = single_image_embeds.to(device=device)
404
+ ip_adapter_image_embeds.append(single_image_embeds)
405
+
406
+ return ip_adapter_image_embeds
508
407
 
509
408
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
510
409
  def encode_prompt(
@@ -516,10 +415,10 @@ class StableDiffusionXLInpaintPipeline(
516
415
  do_classifier_free_guidance: bool = True,
517
416
  negative_prompt: Optional[str] = None,
518
417
  negative_prompt_2: Optional[str] = None,
519
- prompt_embeds: Optional[torch.FloatTensor] = None,
520
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
521
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
522
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
418
+ prompt_embeds: Optional[torch.Tensor] = None,
419
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
420
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
421
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
523
422
  lora_scale: Optional[float] = None,
524
423
  clip_skip: Optional[int] = None,
525
424
  ):
@@ -545,17 +444,17 @@ class StableDiffusionXLInpaintPipeline(
545
444
  negative_prompt_2 (`str` or `List[str]`, *optional*):
546
445
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
547
446
  `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
548
- prompt_embeds (`torch.FloatTensor`, *optional*):
447
+ prompt_embeds (`torch.Tensor`, *optional*):
549
448
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
550
449
  provided, text embeddings will be generated from `prompt` input argument.
551
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
450
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
552
451
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
553
452
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
554
453
  argument.
555
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
454
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
556
455
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
557
456
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
558
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
457
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
559
458
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
560
459
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
561
460
  input argument.
@@ -880,7 +779,12 @@ class StableDiffusionXLInpaintPipeline(
880
779
  return_noise=False,
881
780
  return_image_latents=False,
882
781
  ):
883
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
782
+ shape = (
783
+ batch_size,
784
+ num_channels_latents,
785
+ int(height) // self.vae_scale_factor,
786
+ int(width) // self.vae_scale_factor,
787
+ )
884
788
  if isinstance(generator, list) and len(generator) != batch_size:
885
789
  raise ValueError(
886
790
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1006,14 +910,16 @@ class StableDiffusionXLInpaintPipeline(
1006
910
  if denoising_start is None:
1007
911
  init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
1008
912
  t_start = max(num_inference_steps - init_timestep, 0)
1009
- else:
1010
- t_start = 0
1011
913
 
1012
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
914
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
915
+ if hasattr(self.scheduler, "set_begin_index"):
916
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
1013
917
 
1014
- # Strength is irrelevant if we directly request a timestep to start at;
1015
- # that is, strength is determined by the denoising_start instead.
1016
- if denoising_start is not None:
918
+ return timesteps, num_inference_steps - t_start
919
+
920
+ else:
921
+ # Strength is irrelevant if we directly request a timestep to start at;
922
+ # that is, strength is determined by the denoising_start instead.
1017
923
  discrete_timestep_cutoff = int(
1018
924
  round(
1019
925
  self.scheduler.config.num_train_timesteps
@@ -1021,22 +927,23 @@ class StableDiffusionXLInpaintPipeline(
1021
927
  )
1022
928
  )
1023
929
 
1024
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
930
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
1025
931
  if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
1026
932
  # if the scheduler is a 2nd order scheduler we might have to do +1
1027
933
  # because `num_inference_steps` might be even given that every timestep
1028
934
  # (except the highest one) is duplicated. If `num_inference_steps` is even it would
1029
935
  # mean that we cut the timesteps in the middle of the denoising step
1030
- # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
936
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
1031
937
  # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
1032
938
  num_inference_steps = num_inference_steps + 1
1033
939
 
1034
940
  # because t_n+1 >= t_n, we slice the timesteps starting from the end
1035
- timesteps = timesteps[-num_inference_steps:]
941
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
942
+ timesteps = self.scheduler.timesteps[t_start:]
943
+ if hasattr(self.scheduler, "set_begin_index"):
944
+ self.scheduler.set_begin_index(t_start)
1036
945
  return timesteps, num_inference_steps
1037
946
 
1038
- return timesteps, num_inference_steps - t_start
1039
-
1040
947
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
1041
948
  def _get_add_time_ids(
1042
949
  self,
@@ -1098,8 +1005,6 @@ class StableDiffusionXLInpaintPipeline(
1098
1005
  (
1099
1006
  AttnProcessor2_0,
1100
1007
  XFormersAttnProcessor,
1101
- LoRAXFormersAttnProcessor,
1102
- LoRAAttnProcessor2_0,
1103
1008
  ),
1104
1009
  )
1105
1010
  # if xformers or torch_2_0 is used attention block does not need
@@ -1110,20 +1015,22 @@ class StableDiffusionXLInpaintPipeline(
1110
1015
  self.vae.decoder.mid_block.to(dtype)
1111
1016
 
1112
1017
  # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1113
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
1018
+ def get_guidance_scale_embedding(
1019
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1020
+ ) -> torch.Tensor:
1114
1021
  """
1115
1022
  See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1116
1023
 
1117
1024
  Args:
1118
- timesteps (`torch.Tensor`):
1119
- generate embedding vectors at these timesteps
1025
+ w (`torch.Tensor`):
1026
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1120
1027
  embedding_dim (`int`, *optional*, defaults to 512):
1121
- dimension of the embeddings to generate
1122
- dtype:
1123
- data type of the generated embeddings
1028
+ Dimension of the embeddings to generate.
1029
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1030
+ Data type of the generated embeddings.
1124
1031
 
1125
1032
  Returns:
1126
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
1033
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1127
1034
  """
1128
1035
  assert len(w.shape) == 1
1129
1036
  w = w * 1000.0
@@ -1185,13 +1092,14 @@ class StableDiffusionXLInpaintPipeline(
1185
1092
  prompt_2: Optional[Union[str, List[str]]] = None,
1186
1093
  image: PipelineImageInput = None,
1187
1094
  mask_image: PipelineImageInput = None,
1188
- masked_image_latents: torch.FloatTensor = None,
1095
+ masked_image_latents: torch.Tensor = None,
1189
1096
  height: Optional[int] = None,
1190
1097
  width: Optional[int] = None,
1191
1098
  padding_mask_crop: Optional[int] = None,
1192
1099
  strength: float = 0.9999,
1193
1100
  num_inference_steps: int = 50,
1194
1101
  timesteps: List[int] = None,
1102
+ sigmas: List[float] = None,
1195
1103
  denoising_start: Optional[float] = None,
1196
1104
  denoising_end: Optional[float] = None,
1197
1105
  guidance_scale: float = 7.5,
@@ -1200,13 +1108,13 @@ class StableDiffusionXLInpaintPipeline(
1200
1108
  num_images_per_prompt: Optional[int] = 1,
1201
1109
  eta: float = 0.0,
1202
1110
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1203
- latents: Optional[torch.FloatTensor] = None,
1204
- prompt_embeds: Optional[torch.FloatTensor] = None,
1205
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1206
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1207
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1111
+ latents: Optional[torch.Tensor] = None,
1112
+ prompt_embeds: Optional[torch.Tensor] = None,
1113
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1114
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1115
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1208
1116
  ip_adapter_image: Optional[PipelineImageInput] = None,
1209
- ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1117
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1210
1118
  output_type: Optional[str] = "pil",
1211
1119
  return_dict: bool = True,
1212
1120
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1220,7 +1128,9 @@ class StableDiffusionXLInpaintPipeline(
1220
1128
  aesthetic_score: float = 6.0,
1221
1129
  negative_aesthetic_score: float = 2.5,
1222
1130
  clip_skip: Optional[int] = None,
1223
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1131
+ callback_on_step_end: Optional[
1132
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1133
+ ] = None,
1224
1134
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1225
1135
  **kwargs,
1226
1136
  ):
@@ -1253,11 +1163,12 @@ class StableDiffusionXLInpaintPipeline(
1253
1163
  [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1254
1164
  and checkpoints that are not specifically fine-tuned on low resolutions.
1255
1165
  padding_mask_crop (`int`, *optional*, defaults to `None`):
1256
- The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
1257
- `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
1258
- contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
1259
- the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
1260
- and contain information inreleant for inpainging, such as background.
1166
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1167
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1168
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
1169
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1170
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1171
+ the image is large and contain information irrelevant for inpainting, such as background.
1261
1172
  strength (`float`, *optional*, defaults to 0.9999):
1262
1173
  Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1263
1174
  between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1273,6 +1184,10 @@ class StableDiffusionXLInpaintPipeline(
1273
1184
  Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1274
1185
  in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1275
1186
  passed will be used. Must be in descending order.
1187
+ sigmas (`List[float]`, *optional*):
1188
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1189
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1190
+ will be used.
1276
1191
  denoising_start (`float`, *optional*):
1277
1192
  When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1278
1193
  bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
@@ -1301,26 +1216,26 @@ class StableDiffusionXLInpaintPipeline(
1301
1216
  negative_prompt_2 (`str` or `List[str]`, *optional*):
1302
1217
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1303
1218
  `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1304
- prompt_embeds (`torch.FloatTensor`, *optional*):
1219
+ prompt_embeds (`torch.Tensor`, *optional*):
1305
1220
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1306
1221
  provided, text embeddings will be generated from `prompt` input argument.
1307
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1222
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1308
1223
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1309
1224
  weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1310
1225
  argument.
1311
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1226
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1312
1227
  Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1313
1228
  If not provided, pooled text embeddings will be generated from `prompt` input argument.
1314
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1229
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1315
1230
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1316
1231
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1317
1232
  input argument.
1318
1233
  ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1319
- ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1320
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
1321
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
1322
- if `do_classifier_free_guidance` is set to `True`.
1323
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
1234
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1235
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
1236
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1237
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1238
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1324
1239
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1325
1240
  The number of images to generate per prompt.
1326
1241
  eta (`float`, *optional*, defaults to 0.0):
@@ -1329,7 +1244,7 @@ class StableDiffusionXLInpaintPipeline(
1329
1244
  generator (`torch.Generator`, *optional*):
1330
1245
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1331
1246
  to make generation deterministic.
1332
- latents (`torch.FloatTensor`, *optional*):
1247
+ latents (`torch.Tensor`, *optional*):
1333
1248
  Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1334
1249
  generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1335
1250
  tensor will ge generated by sampling using the supplied random `generator`.
@@ -1383,11 +1298,11 @@ class StableDiffusionXLInpaintPipeline(
1383
1298
  clip_skip (`int`, *optional*):
1384
1299
  Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1385
1300
  the output of the pre-final layer will be used for computing the prompt embeddings.
1386
- callback_on_step_end (`Callable`, *optional*):
1387
- A function that calls at the end of each denoising steps during the inference. The function is called
1388
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1389
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1390
- `callback_on_step_end_tensor_inputs`.
1301
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1302
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1303
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
1304
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1305
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1391
1306
  callback_on_step_end_tensor_inputs (`List`, *optional*):
1392
1307
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1393
1308
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -1417,6 +1332,9 @@ class StableDiffusionXLInpaintPipeline(
1417
1332
  "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1418
1333
  )
1419
1334
 
1335
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1336
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1337
+
1420
1338
  # 0. Default height and width to unet
1421
1339
  height = height or self.unet.config.sample_size * self.vae_scale_factor
1422
1340
  width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -1490,7 +1408,9 @@ class StableDiffusionXLInpaintPipeline(
1490
1408
  def denoising_value_valid(dnv):
1491
1409
  return isinstance(dnv, float) and 0 < dnv < 1
1492
1410
 
1493
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1411
+ timesteps, num_inference_steps = retrieve_timesteps(
1412
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1413
+ )
1494
1414
  timesteps, num_inference_steps = self.get_timesteps(
1495
1415
  num_inference_steps,
1496
1416
  strength,
@@ -1718,7 +1638,12 @@ class StableDiffusionXLInpaintPipeline(
1718
1638
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1719
1639
 
1720
1640
  # compute the previous noisy sample x_t -> x_t-1
1641
+ latents_dtype = latents.dtype
1721
1642
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1643
+ if latents.dtype != latents_dtype:
1644
+ if torch.backends.mps.is_available():
1645
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1646
+ latents = latents.to(latents_dtype)
1722
1647
 
1723
1648
  if num_channels_unet == 4:
1724
1649
  init_latents_proper = image_latents
@@ -1743,13 +1668,8 @@ class StableDiffusionXLInpaintPipeline(
1743
1668
 
1744
1669
  latents = callback_outputs.pop("latents", latents)
1745
1670
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1746
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1747
1671
  add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1748
- negative_pooled_prompt_embeds = callback_outputs.pop(
1749
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1750
- )
1751
1672
  add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1752
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1753
1673
  mask = callback_outputs.pop("mask", mask)
1754
1674
  masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
1755
1675
 
@@ -1770,6 +1690,10 @@ class StableDiffusionXLInpaintPipeline(
1770
1690
  if needs_upcasting:
1771
1691
  self.upcast_vae()
1772
1692
  latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1693
+ elif latents.dtype != self.vae.dtype:
1694
+ if torch.backends.mps.is_available():
1695
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1696
+ self.vae = self.vae.to(latents.dtype)
1773
1697
 
1774
1698
  # unscale/denormalize the latents
1775
1699
  # denormalize with the mean and std if available and not None