diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445) hide show
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -13,13 +13,38 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import math
16
+ from dataclasses import dataclass
16
17
  from typing import List, Optional, Tuple, Union
17
18
 
18
19
  import numpy as np
19
20
  import torch
20
21
 
21
22
  from ..configuration_utils import ConfigMixin, register_to_config
22
- from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
23
+ from ..utils import BaseOutput, is_scipy_available
24
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
25
+
26
+
27
+ if is_scipy_available():
28
+ import scipy.stats
29
+
30
+
31
+ @dataclass
32
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
33
+ class HeunDiscreteSchedulerOutput(BaseOutput):
34
+ """
35
+ Output class for the scheduler's `step` function output.
36
+
37
+ Args:
38
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40
+ denoising loop.
41
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
42
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
43
+ `pred_original_sample` can be used to preview progress or for guidance.
44
+ """
45
+
46
+ prev_sample: torch.Tensor
47
+ pred_original_sample: Optional[torch.Tensor] = None
23
48
 
24
49
 
25
50
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -57,7 +82,7 @@ def betas_for_alpha_bar(
57
82
  return math.exp(t * -12.0)
58
83
 
59
84
  else:
60
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
85
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
61
86
 
62
87
  betas = []
63
88
  for i in range(num_diffusion_timesteps):
@@ -97,6 +122,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
97
122
  use_karras_sigmas (`bool`, *optional*, defaults to `False`):
98
123
  Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
99
124
  the sigmas are determined according to a sequence of noise levels {σi}.
125
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
126
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
127
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
128
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
129
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
100
130
  timestep_spacing (`str`, defaults to `"linspace"`):
101
131
  The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
102
132
  Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -117,11 +147,19 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
117
147
  trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
118
148
  prediction_type: str = "epsilon",
119
149
  use_karras_sigmas: Optional[bool] = False,
150
+ use_exponential_sigmas: Optional[bool] = False,
151
+ use_beta_sigmas: Optional[bool] = False,
120
152
  clip_sample: Optional[bool] = False,
121
153
  clip_sample_range: float = 1.0,
122
154
  timestep_spacing: str = "linspace",
123
155
  steps_offset: int = 0,
124
156
  ):
157
+ if self.config.use_beta_sigmas and not is_scipy_available():
158
+ raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
159
+ if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
160
+ raise ValueError(
161
+ "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
162
+ )
125
163
  if trained_betas is not None:
126
164
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
127
165
  elif beta_schedule == "linear":
@@ -135,7 +173,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
135
173
  elif beta_schedule == "exp":
136
174
  self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
137
175
  else:
138
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
176
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
139
177
 
140
178
  self.alphas = 1.0 - self.betas
141
179
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -174,7 +212,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
174
212
  @property
175
213
  def step_index(self):
176
214
  """
177
- The index counter for current timestep. It will increae 1 after each scheduler step.
215
+ The index counter for current timestep. It will increase 1 after each scheduler step.
178
216
  """
179
217
  return self._step_index
180
218
 
@@ -198,21 +236,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
198
236
 
199
237
  def scale_model_input(
200
238
  self,
201
- sample: torch.FloatTensor,
202
- timestep: Union[float, torch.FloatTensor],
203
- ) -> torch.FloatTensor:
239
+ sample: torch.Tensor,
240
+ timestep: Union[float, torch.Tensor],
241
+ ) -> torch.Tensor:
204
242
  """
205
243
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
206
244
  current timestep.
207
245
 
208
246
  Args:
209
- sample (`torch.FloatTensor`):
247
+ sample (`torch.Tensor`):
210
248
  The input sample.
211
249
  timestep (`int`, *optional*):
212
250
  The current timestep in the diffusion chain.
213
251
 
214
252
  Returns:
215
- `torch.FloatTensor`:
253
+ `torch.Tensor`:
216
254
  A scaled input sample.
217
255
  """
218
256
  if self.step_index is None:
@@ -224,9 +262,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
224
262
 
225
263
  def set_timesteps(
226
264
  self,
227
- num_inference_steps: int,
265
+ num_inference_steps: Optional[int] = None,
228
266
  device: Union[str, torch.device] = None,
229
267
  num_train_timesteps: Optional[int] = None,
268
+ timesteps: Optional[List[int]] = None,
230
269
  ):
231
270
  """
232
271
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -236,30 +275,51 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
236
275
  The number of diffusion steps used when generating samples with a pre-trained model.
237
276
  device (`str` or `torch.device`, *optional*):
238
277
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
278
+ num_train_timesteps (`int`, *optional*):
279
+ The number of diffusion steps used when training the model. If `None`, the default
280
+ `num_train_timesteps` attribute is used.
281
+ timesteps (`List[int]`, *optional*):
282
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
283
+ generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
284
+ must be `None`, and `timestep_spacing` attribute will be ignored.
239
285
  """
286
+ if num_inference_steps is None and timesteps is None:
287
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
288
+ if num_inference_steps is not None and timesteps is not None:
289
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
290
+ if timesteps is not None and self.config.use_karras_sigmas:
291
+ raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
292
+ if timesteps is not None and self.config.use_exponential_sigmas:
293
+ raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
294
+ if timesteps is not None and self.config.use_beta_sigmas:
295
+ raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
296
+
297
+ num_inference_steps = num_inference_steps or len(timesteps)
240
298
  self.num_inference_steps = num_inference_steps
241
-
242
299
  num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
243
300
 
244
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
245
- if self.config.timestep_spacing == "linspace":
246
- timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
247
- elif self.config.timestep_spacing == "leading":
248
- step_ratio = num_train_timesteps // self.num_inference_steps
249
- # creates integer timesteps by multiplying by ratio
250
- # casting to int to avoid issues when num_inference_step is power of 3
251
- timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
252
- timesteps += self.config.steps_offset
253
- elif self.config.timestep_spacing == "trailing":
254
- step_ratio = num_train_timesteps / self.num_inference_steps
255
- # creates integer timesteps by multiplying by ratio
256
- # casting to int to avoid issues when num_inference_step is power of 3
257
- timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
258
- timesteps -= 1
301
+ if timesteps is not None:
302
+ timesteps = np.array(timesteps, dtype=np.float32)
259
303
  else:
260
- raise ValueError(
261
- f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
262
- )
304
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
305
+ if self.config.timestep_spacing == "linspace":
306
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
307
+ elif self.config.timestep_spacing == "leading":
308
+ step_ratio = num_train_timesteps // self.num_inference_steps
309
+ # creates integer timesteps by multiplying by ratio
310
+ # casting to int to avoid issues when num_inference_step is power of 3
311
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
312
+ timesteps += self.config.steps_offset
313
+ elif self.config.timestep_spacing == "trailing":
314
+ step_ratio = num_train_timesteps / self.num_inference_steps
315
+ # creates integer timesteps by multiplying by ratio
316
+ # casting to int to avoid issues when num_inference_step is power of 3
317
+ timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
318
+ timesteps -= 1
319
+ else:
320
+ raise ValueError(
321
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
322
+ )
263
323
 
264
324
  sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
265
325
  log_sigmas = np.log(sigmas)
@@ -268,6 +328,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
268
328
  if self.config.use_karras_sigmas:
269
329
  sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
270
330
  timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
331
+ elif self.config.use_exponential_sigmas:
332
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
333
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
334
+ elif self.config.use_beta_sigmas:
335
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
336
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
271
337
 
272
338
  sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
273
339
  sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -311,7 +377,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
311
377
  return t
312
378
 
313
379
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
314
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
380
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
315
381
  """Constructs the noise schedule of Karras et al. (2022)."""
316
382
 
317
383
  # Hack to make sure that other schedulers which copy this function don't break
@@ -336,6 +402,60 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
336
402
  sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
337
403
  return sigmas
338
404
 
405
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
406
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
407
+ """Constructs an exponential noise schedule."""
408
+
409
+ # Hack to make sure that other schedulers which copy this function don't break
410
+ # TODO: Add this logic to the other schedulers
411
+ if hasattr(self.config, "sigma_min"):
412
+ sigma_min = self.config.sigma_min
413
+ else:
414
+ sigma_min = None
415
+
416
+ if hasattr(self.config, "sigma_max"):
417
+ sigma_max = self.config.sigma_max
418
+ else:
419
+ sigma_max = None
420
+
421
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
422
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
423
+
424
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
425
+ return sigmas
426
+
427
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
428
+ def _convert_to_beta(
429
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
430
+ ) -> torch.Tensor:
431
+ """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
432
+
433
+ # Hack to make sure that other schedulers which copy this function don't break
434
+ # TODO: Add this logic to the other schedulers
435
+ if hasattr(self.config, "sigma_min"):
436
+ sigma_min = self.config.sigma_min
437
+ else:
438
+ sigma_min = None
439
+
440
+ if hasattr(self.config, "sigma_max"):
441
+ sigma_max = self.config.sigma_max
442
+ else:
443
+ sigma_max = None
444
+
445
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
446
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
447
+
448
+ sigmas = np.array(
449
+ [
450
+ sigma_min + (ppf * (sigma_max - sigma_min))
451
+ for ppf in [
452
+ scipy.stats.beta.ppf(timestep, alpha, beta)
453
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
454
+ ]
455
+ ]
456
+ )
457
+ return sigmas
458
+
339
459
  @property
340
460
  def state_in_first_order(self):
341
461
  return self.dt is None
@@ -351,29 +471,30 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
351
471
 
352
472
  def step(
353
473
  self,
354
- model_output: Union[torch.FloatTensor, np.ndarray],
355
- timestep: Union[float, torch.FloatTensor],
356
- sample: Union[torch.FloatTensor, np.ndarray],
474
+ model_output: Union[torch.Tensor, np.ndarray],
475
+ timestep: Union[float, torch.Tensor],
476
+ sample: Union[torch.Tensor, np.ndarray],
357
477
  return_dict: bool = True,
358
- ) -> Union[SchedulerOutput, Tuple]:
478
+ ) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
359
479
  """
360
480
  Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
361
481
  process from the learned model outputs (most often the predicted noise).
362
482
 
363
483
  Args:
364
- model_output (`torch.FloatTensor`):
484
+ model_output (`torch.Tensor`):
365
485
  The direct output from learned diffusion model.
366
486
  timestep (`float`):
367
487
  The current discrete timestep in the diffusion chain.
368
- sample (`torch.FloatTensor`):
488
+ sample (`torch.Tensor`):
369
489
  A current instance of a sample created by the diffusion process.
370
490
  return_dict (`bool`):
371
- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
491
+ Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
492
+ tuple.
372
493
 
373
494
  Returns:
374
- [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
375
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
376
- tuple is returned where the first element is the sample tensor.
495
+ [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
496
+ If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
497
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
377
498
  """
378
499
  if self.step_index is None:
379
500
  self._init_step_index(timestep)
@@ -444,17 +565,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
444
565
  self._step_index += 1
445
566
 
446
567
  if not return_dict:
447
- return (prev_sample,)
568
+ return (
569
+ prev_sample,
570
+ pred_original_sample,
571
+ )
448
572
 
449
- return SchedulerOutput(prev_sample=prev_sample)
573
+ return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
450
574
 
451
575
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
452
576
  def add_noise(
453
577
  self,
454
- original_samples: torch.FloatTensor,
455
- noise: torch.FloatTensor,
456
- timesteps: torch.FloatTensor,
457
- ) -> torch.FloatTensor:
578
+ original_samples: torch.Tensor,
579
+ noise: torch.Tensor,
580
+ timesteps: torch.Tensor,
581
+ ) -> torch.Tensor:
458
582
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
459
583
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
460
584
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -468,7 +592,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
468
592
  # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
469
593
  if self.begin_index is None:
470
594
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
595
+ elif self.step_index is not None:
596
+ # add_noise is called after first denoising step (for inpainting)
597
+ step_indices = [self.step_index] * timesteps.shape[0]
471
598
  else:
599
+ # add noise is called before first denoising step to create initial latent(img2img)
472
600
  step_indices = [self.begin_index] * timesteps.shape[0]
473
601
 
474
602
  sigma = sigmas[step_indices].flatten()
@@ -61,7 +61,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
61
61
  @property
62
62
  def step_index(self):
63
63
  """
64
- The index counter for current timestep. It will increae 1 after each scheduler step.
64
+ The index counter for current timestep. It will increase 1 after each scheduler step.
65
65
  """
66
66
  return self._step_index
67
67
 
@@ -137,9 +137,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
137
137
 
138
138
  def step(
139
139
  self,
140
- model_output: torch.FloatTensor,
141
- timestep: int,
142
- sample: torch.FloatTensor,
140
+ model_output: torch.Tensor,
141
+ timestep: Union[int, torch.Tensor],
142
+ sample: torch.Tensor,
143
143
  return_dict: bool = True,
144
144
  ) -> Union[SchedulerOutput, Tuple]:
145
145
  """
@@ -147,11 +147,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
147
147
  the linear multistep method. It performs one forward pass multiple times to approximate the solution.
148
148
 
149
149
  Args:
150
- model_output (`torch.FloatTensor`):
150
+ model_output (`torch.Tensor`):
151
151
  The direct output from learned diffusion model.
152
152
  timestep (`int`):
153
153
  The current discrete timestep in the diffusion chain.
154
- sample (`torch.FloatTensor`):
154
+ sample (`torch.Tensor`):
155
155
  A current instance of a sample created by the diffusion process.
156
156
  return_dict (`bool`):
157
157
  Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -193,17 +193,17 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
193
193
 
194
194
  return SchedulerOutput(prev_sample=prev_sample)
195
195
 
196
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
196
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
197
197
  """
198
198
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
199
199
  current timestep.
200
200
 
201
201
  Args:
202
- sample (`torch.FloatTensor`):
202
+ sample (`torch.Tensor`):
203
203
  The input sample.
204
204
 
205
205
  Returns:
206
- `torch.FloatTensor`:
206
+ `torch.Tensor`:
207
207
  A scaled input sample.
208
208
  """
209
209
  return sample