diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_euler_discrete.py
@@ -20,11 +20,14 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -35,16 +38,16 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +85,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -99,11 +102,11 @@ def rescale_zero_terminal_snr(betas):
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
    Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
@@ -158,6 +161,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -167,6 +175,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -183,13 +194,22 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prediction_type: str = "epsilon",
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
         timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
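
A minimal usage sketch of the new constructor flags (not part of the diff; the beta settings are illustrative SD-style values). `use_beta_sigmas` requires scipy, and only one of the three sigma options may be enabled:

from diffusers import EulerDiscreteScheduler

# Raises ImportError without scipy installed; enabling more than one of
# use_karras_sigmas / use_exponential_sigmas / use_beta_sigmas raises ValueError.
scheduler = EulerDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    use_beta_sigmas=True,
    final_sigmas_type="zero",  # new config option; "zero" (default) or "sigma_min"
)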
@@ -201,7 +221,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         if rescale_betas_zero_snr:
             self.betas = rescale_zero_terminal_snr(self.betas)
@@ -231,6 +251,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
+        self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -248,7 +270,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -270,21 +292,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -296,7 +316,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -305,60 +331,123 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
+                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
+                must be `None`, and `timestep_spacing` attribute will be ignored.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
+                will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
+                `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
+                custom sigmas schedule.
         """
-        self.num_inference_steps = num_inference_steps
 
-        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-        if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
-                ::-1
-            ].copy()
-        elif self.config.timestep_spacing == "leading":
-            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-            timesteps += self.config.steps_offset
-        elif self.config.timestep_spacing == "trailing":
-            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
-            timesteps -= 1
-        else:
+        if timesteps is not None and sigmas is not None:
+            raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
+        if num_inference_steps is None and timesteps is None and sigmas is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.")
+        if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if timesteps is not None and self.config.use_exponential_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
+        if (
+            timesteps is not None
+            and self.config.timestep_type == "continuous"
+            and self.config.prediction_type == "v_prediction"
+        ):
             raise ValueError(
-                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
             )
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        log_sigmas = np.log(sigmas)
+        if num_inference_steps is None:
+            num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
+        self.num_inference_steps = num_inference_steps
 
-        if self.config.interpolation_type == "linear":
-            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-        elif self.config.interpolation_type == "log_linear":
-            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
-        else:
-            raise ValueError(
-                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
-                " 'linear' or 'log_linear'"
-            )
+        if sigmas is not None:
+            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
+            sigmas = np.array(sigmas).astype(np.float32)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
 
-        if self.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        else:
+            if timesteps is not None:
+                timesteps = np.array(timesteps).astype(np.float32)
+            else:
+                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+                if self.config.timestep_spacing == "linspace":
+                    timesteps = np.linspace(
+                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
+                    )[::-1].copy()
+                elif self.config.timestep_spacing == "leading":
+                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+                    # creates integer timesteps by multiplying by ratio
+                    # casting to int to avoid issues when num_inference_step is power of 3
+                    timesteps = (
+                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                    )
+                    timesteps += self.config.steps_offset
+                elif self.config.timestep_spacing == "trailing":
+                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+                    # creates integer timesteps by multiplying by ratio
+                    # casting to int to avoid issues when num_inference_step is power of 3
+                    timesteps = (
+                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                    )
+                    timesteps -= 1
+                else:
+                    raise ValueError(
+                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                    )
+
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            if self.config.interpolation_type == "linear":
+                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+            elif self.config.interpolation_type == "log_linear":
+                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
+            else:
+                raise ValueError(
+                    f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+                    " 'linear' or 'log_linear'"
+                )
+
+            if self.config.use_karras_sigmas:
+                sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+            elif self.config.use_exponential_sigmas:
+                sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+            elif self.config.use_beta_sigmas:
+                sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
 
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
-            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
+            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
         else:
             self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
-        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
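
A hedged sketch of the reworked `set_timesteps` (not part of the diff; sigma values are illustrative): exactly one of `num_inference_steps`, `timesteps`, or `sigmas` may be passed, and a custom `sigmas` list is used as-is, including the trailing final sigma, so `num_inference_steps` becomes `len(sigmas) - 1`.

from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

custom_sigmas = [10.0, 6.0, 3.0, 1.5, 0.75, 0.35, 0.15, 0.0]  # illustrative, strictly decreasing, ends at 0
scheduler.set_timesteps(sigmas=custom_sigmas, device="cpu")
print(scheduler.num_inference_steps)  # 7, i.e. len(custom_sigmas) - 1
print(scheduler.timesteps)            # timesteps interpolated from the first 7 sigmas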
@@ -384,7 +473,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -409,6 +498,59 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
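
For intuition, the added `_convert_to_beta` maps evenly spaced quantile levels through the percent-point function of a Beta(0.6, 0.6) distribution and rescales them into `[sigma_min, sigma_max]`, which clusters sampling points near both ends of the sigma range. A standalone sketch of that mapping (values are illustrative, not taken from the diff):

import numpy as np
import scipy.stats

alpha, beta = 0.6, 0.6
sigma_min, sigma_max = 0.03, 14.6
num_inference_steps = 10

ppfs = [scipy.stats.beta.ppf(t, alpha, beta) for t in 1 - np.linspace(0, 1, num_inference_steps)]
sigmas = np.array([sigma_min + p * (sigma_max - sigma_min) for p in ppfs])
print(sigmas.round(3))  # monotonically decreasing from sigma_max to sigma_min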
@@ -433,9 +575,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         s_churn: float = 0.0,
         s_tmin: float = 0.0,
         s_tmax: float = float("inf"),
@@ -448,11 +590,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             s_churn (`float`):
             s_tmin (`float`):
@@ -471,11 +613,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
 
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -500,14 +638,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
 
-        noise = randn_tensor(
-            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
-        )
-
-        eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
 
         if gamma > 0:
+            noise = randn_tensor(
+                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+            )
+            eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -539,16 +676,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
         return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
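
Since `step(..., return_dict=False)` now returns a 2-tuple, tuple-unpacking callers also receive the current denoised estimate. A minimal, self-contained sketch (the random tensor stands in for a UNet prediction; not part of the diff):

import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=5)

sample = torch.randn(1, 4, 8, 8)
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)  # Euler (sigma**2 + 1) ** 0.5 input scaling
    model_output = torch.randn_like(model_input)          # stand-in for a UNet prediction
    # return_dict=False now yields (prev_sample, pred_original_sample) instead of (prev_sample,)
    sample, x0_estimate = scheduler.step(model_output, t, sample, return_dict=False)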
@@ -562,7 +702,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
@@ -572,5 +716,42 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        if (
+            isinstance(timesteps, int)
+            or isinstance(timesteps, torch.IntTensor)
+            or isinstance(timesteps, torch.LongTensor)
+        ):
+            raise ValueError(
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
+            )
+
+        if sample.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
+            timesteps = timesteps.to(sample.device, dtype=torch.float32)
+        else:
+            schedule_timesteps = self.timesteps.to(sample.device)
+            timesteps = timesteps.to(sample.device)
+
+        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        alphas_cumprod = self.alphas_cumprod.to(sample)
+        sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
     def __len__(self):
         return self.config.num_train_timesteps
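
The new `get_velocity` helper computes the v-prediction target `v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample`, mirroring the same helper in the DDPM/DDIM schedulers. A hedged training-side sketch (random tensors as stand-ins; the surrounding training loop is assumed and not part of the diff):

import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(prediction_type="v_prediction")

latents = torch.randn(2, 4, 8, 8)   # stand-in for encoded training images
noise = torch.randn_like(latents)
indices = torch.randint(0, len(scheduler.timesteps), (latents.shape[0],))
timesteps = scheduler.timesteps[indices]

target = scheduler.get_velocity(latents, noise, timesteps)
# a v-prediction model's output for the corresponding noisy latents would be regressed onto `target`
# (the model, noisy latents, and loss computation are outside the scope of this sketch)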