diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
1
+ from typing import TYPE_CHECKING, Any, Dict, List, Union
2
+
3
+ from diffusers.utils.import_utils import is_optimum_quanto_version
4
+
5
+ from ...utils import (
6
+ get_module_from_name,
7
+ is_accelerate_available,
8
+ is_accelerate_version,
9
+ is_optimum_quanto_available,
10
+ is_torch_available,
11
+ logging,
12
+ )
13
+ from ..base import DiffusersQuantizer
14
+
15
+
16
+ if TYPE_CHECKING:
17
+ from ...models.modeling_utils import ModelMixin
18
+
19
+
20
+ if is_torch_available():
21
+ import torch
22
+
23
+ if is_accelerate_available():
24
+ from accelerate.utils import CustomDtype, set_module_tensor_to_device
25
+
26
+ if is_optimum_quanto_available():
27
+ from .utils import _replace_with_quanto_layers
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
class QuantoQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for Optimum Quanto.

    Replaces eligible `nn.Linear` layers with Quanto quantized modules before weight
    loading, and freezes the quantized weights after they are materialized.
    """

    # Modules listed in `_keep_in_fp32_modules` are excluded from quantization.
    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["quanto", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        """Check that required libraries are installed and the requested `device_map` is supported.

        Raises:
            ImportError: if `optimum-quanto` (>= 0.2.6) or `accelerate` is not installed.
            ValueError: if a multi-device `device_map` is requested (multi-GPU / CPU-disk
                offload is not supported with the Quanto backend).
        """
        if not is_optimum_quanto_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
            )
        if not is_optimum_quanto_version(">=", "0.2.6"):
            raise ImportError(
                "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
                "Please upgrade your installation with `pip install --upgrade optimum-quanto`"
            )

        if not is_accelerate_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
            )

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            raise ValueError(
                "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
            )

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """Return whether `param_name` should be handled by `create_quantized_param`.

        A parameter is considered quantized when it belongs to a Quanto module that is
        either pre-quantized or not yet frozen.
        """
        # Quanto imports diffusers internally. This is here to prevent circular imports
        from optimum.quanto import QModuleMixin, QTensor
        from optimum.quanto.tensor.packed import PackedTensor

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
            return True
        elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
            # Unfrozen Quanto modules still need their weight quantized/frozen.
            return not module.frozen

        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .freeze() after setting it to the module.
        """

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            # Pre-quantized checkpoints already contain quantized tensors; just attach them.
            setattr(module, tensor_name, param_value)
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            module.freeze()
            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        """Reserve a 10% headroom on every device budget for quantization overhead."""
        # NOTE(review): assumes numeric values; a str budget (e.g. "10GiB") would fail here — confirm callers.
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        """Map the configured Quanto weights dtype to the dtype accelerate uses for memory planning."""
        if is_accelerate_version(">=", "0.27.0"):
            mapping = {
                "int8": torch.int8,
                "float8": CustomDtype.FP8,
                "int4": CustomDtype.INT4,
                "int2": CustomDtype.INT2,
            }
            target_dtype = mapping[self.quantization_config.weights_dtype]

        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        """Default `torch_dtype` to `torch.float32` when the caller did not specify one."""
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        """Drop Quanto bookkeeping tensors (non weight/bias keys of quantized modules) from `missing_keys`."""
        # Quanto imports diffusers internally. This is here to prevent circular imports
        from optimum.quanto import QModuleMixin

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, QModuleMixin):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        """Swap eligible `nn.Linear` layers for Quanto layers before the state dict is loaded."""
        # Normalize to a *fresh* list: a `None` config value must not become `[None]`,
        # and we must never mutate `quantization_config.modules_to_not_convert` in place
        # (the previous code aliased the config's own list and `extend`ed it).
        configured = self.quantization_config.modules_to_not_convert
        if configured is None:
            self.modules_to_not_convert = []
        elif isinstance(configured, list):
            self.modules_to_not_convert = list(configured)
        else:
            self.modules_to_not_convert = [configured]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        model = _replace_with_quanto_layers(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        """No post-loading processing is needed for Quanto; return the model unchanged."""
        return model

    @property
    def is_trainable(self):
        # Quanto-quantized models support further (e.g. LoRA) training.
        return True

    @property
    def is_serializable(self):
        return True
@@ -0,0 +1,60 @@
1
+ import torch.nn as nn
2
+
3
+ from ...utils import is_accelerate_available, logging
4
+
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+ if is_accelerate_available():
9
+ from accelerate import init_empty_weights
10
+
11
+
12
+ def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
13
+ # Quanto imports diffusers internally. These are placed here to avoid circular imports
14
+ from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8
15
+
16
+ def _get_weight_type(dtype: str):
17
+ return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]
18
+
19
+ def _replace_layers(model, quantization_config, modules_to_not_convert):
20
+ has_children = list(model.children())
21
+ if not has_children:
22
+ return model
23
+
24
+ for name, module in model.named_children():
25
+ _replace_layers(module, quantization_config, modules_to_not_convert)
26
+
27
+ if name in modules_to_not_convert:
28
+ continue
29
+
30
+ if isinstance(module, nn.Linear):
31
+ with init_empty_weights():
32
+ qlinear = QLinear(
33
+ in_features=module.in_features,
34
+ out_features=module.out_features,
35
+ bias=module.bias is not None,
36
+ dtype=module.weight.dtype,
37
+ weights=_get_weight_type(quantization_config.weights_dtype),
38
+ )
39
+ model._modules[name] = qlinear
40
+ model._modules[name].source_cls = type(module)
41
+ model._modules[name].requires_grad_(False)
42
+
43
+ return model
44
+
45
+ model = _replace_layers(model, quantization_config, modules_to_not_convert)
46
+ has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())
47
+
48
+ if not has_been_replaced:
49
+ logger.warning(
50
+ f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
51
+ " Please check your model architecture, or submit an issue on Github if you think this is a bug."
52
+ " https://github.com/huggingface/diffusers/issues/new"
53
+ )
54
+
55
+ # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
56
+ # to match when trying to load weights with load_model_dict_into_meta
57
+ if pre_quantized:
58
+ freeze(model)
59
+
60
+ return model
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
23
23
 
24
24
  from packaging import version
25
25
 
26
- from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
26
+ from ...utils import (
27
+ get_module_from_name,
28
+ is_torch_available,
29
+ is_torch_version,
30
+ is_torchao_available,
31
+ is_torchao_version,
32
+ logging,
33
+ )
27
34
  from ..base import DiffusersQuantizer
28
35
 
29
36
 
@@ -62,6 +69,43 @@ if is_torchao_available():
62
69
  from torchao.quantization import quantize_
63
70
 
64
71
 
72
+ def _update_torch_safe_globals():
73
+ safe_globals = [
74
+ (torch.uint1, "torch.uint1"),
75
+ (torch.uint2, "torch.uint2"),
76
+ (torch.uint3, "torch.uint3"),
77
+ (torch.uint4, "torch.uint4"),
78
+ (torch.uint5, "torch.uint5"),
79
+ (torch.uint6, "torch.uint6"),
80
+ (torch.uint7, "torch.uint7"),
81
+ ]
82
+ try:
83
+ from torchao.dtypes import NF4Tensor
84
+ from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
85
+ from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
86
+ from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
87
+
88
+ safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
89
+
90
+ except (ImportError, ModuleNotFoundError) as e:
91
+ logger.warning(
92
+ "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
93
+ )
94
+ logger.debug(e)
95
+
96
+ finally:
97
+ torch.serialization.add_safe_globals(safe_globals=safe_globals)
98
+
99
+
100
+ if (
101
+ is_torch_available()
102
+ and is_torch_version(">=", "2.6.0")
103
+ and is_torchao_available()
104
+ and is_torchao_version(">=", "0.7.0")
105
+ ):
106
+ _update_torch_safe_globals()
107
+
108
+
65
109
  logger = logging.get_logger(__name__)
66
110
 
67
111
 
@@ -215,6 +259,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
215
259
  target_device: "torch.device",
216
260
  state_dict: Dict[str, Any],
217
261
  unexpected_keys: List[str],
262
+ **kwargs,
218
263
  ):
219
264
  r"""
220
265
  Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
@@ -68,6 +68,7 @@ else:
68
68
  _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
69
69
  _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
70
70
  _import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
71
+ _import_structure["scheduling_scm"] = ["SCMScheduler"]
71
72
  _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
72
73
  _import_structure["scheduling_tcd"] = ["TCDScheduler"]
73
74
  _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
@@ -168,13 +169,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
168
169
  from .scheduling_pndm import PNDMScheduler
169
170
  from .scheduling_repaint import RePaintScheduler
170
171
  from .scheduling_sasolver import SASolverScheduler
172
+ from .scheduling_scm import SCMScheduler
171
173
  from .scheduling_sde_ve import ScoreSdeVeScheduler
172
174
  from .scheduling_tcd import TCDScheduler
173
175
  from .scheduling_unclip import UnCLIPScheduler
174
176
  from .scheduling_unipc_multistep import UniPCMultistepScheduler
175
177
  from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
176
178
  from .scheduling_vq_diffusion import VQDiffusionScheduler
177
-
178
179
  try:
179
180
  if not is_flax_available():
180
181
  raise OptionalDependencyNotAvailable()
@@ -203,8 +203,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
203
203
 
204
204
  if timesteps[0] >= self.config.num_train_timesteps:
205
205
  raise ValueError(
206
- f"`timesteps` must start before `self.config.train_timesteps`:"
207
- f" {self.config.num_train_timesteps}."
206
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
208
207
  )
209
208
 
210
209
  timesteps = np.array(timesteps, dtype=np.int64)
@@ -266,7 +266,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
266
266
 
267
267
  self.num_inference_steps = num_inference_steps
268
268
 
269
- # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
269
+ # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
270
270
  if self.config.timestep_spacing == "leading":
271
271
  step_ratio = self.config.num_train_timesteps // self.num_inference_steps
272
272
  # creates integer timesteps by multiplying by ratio
@@ -142,7 +142,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
142
142
  The final `beta` value.
143
143
  beta_schedule (`str`, defaults to `"linear"`):
144
144
  The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
145
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
145
+ `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`.
146
146
  trained_betas (`np.ndarray`, *optional*):
147
147
  An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
148
148
  variance_type (`str`, defaults to `"fixed_small"`):
@@ -279,8 +279,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
279
279
 
280
280
  if timesteps[0] >= self.config.num_train_timesteps:
281
281
  raise ValueError(
282
- f"`timesteps` must start before `self.config.train_timesteps`:"
283
- f" {self.config.num_train_timesteps}."
282
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
284
283
  )
285
284
 
286
285
  timesteps = np.array(timesteps, dtype=np.int64)
@@ -289,8 +289,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
289
289
 
290
290
  if timesteps[0] >= self.config.num_train_timesteps:
291
291
  raise ValueError(
292
- f"`timesteps` must start before `self.config.train_timesteps`:"
293
- f" {self.config.num_train_timesteps}."
292
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
294
293
  )
295
294
 
296
295
  timesteps = np.array(timesteps, dtype=np.int64)
@@ -136,8 +136,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
136
136
  sampling, and `solver_order=3` for unconditional sampling.
137
137
  prediction_type (`str`, defaults to `epsilon`, *optional*):
138
138
  Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
139
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
140
- Video](https://imagen.research.google/video/paper.pdf) paper).
139
+ `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
140
+ Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
141
141
  thresholding (`bool`, defaults to `False`):
142
142
  Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
143
143
  as Stable Diffusion.
@@ -174,6 +174,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
174
174
  Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
175
175
  the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
176
176
  `lambda(t)`.
177
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
178
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
179
+ flow_shift (`float`, *optional*, defaults to 1.0):
180
+ The shift value for the timestep schedule for flow matching.
177
181
  final_sigmas_type (`str`, defaults to `"zero"`):
178
182
  The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
179
183
  sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -395,12 +399,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
395
399
  if self.config.use_karras_sigmas:
396
400
  sigmas = np.flip(sigmas).copy()
397
401
  sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
398
- timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
402
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
403
+ if self.config.beta_schedule != "squaredcos_cap_v2":
404
+ timesteps = timesteps.round()
399
405
  elif self.config.use_lu_lambdas:
400
406
  lambdas = np.flip(log_sigmas.copy())
401
407
  lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
402
408
  sigmas = np.exp(lambdas)
403
- timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
409
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
410
+ if self.config.beta_schedule != "squaredcos_cap_v2":
411
+ timesteps = timesteps.round()
404
412
  elif self.config.use_exponential_sigmas:
405
413
  sigmas = np.flip(sigmas).copy()
406
414
  sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -14,7 +14,7 @@
14
14
 
15
15
  import math
16
16
  from dataclasses import dataclass
17
- from typing import Optional, Tuple, Union
17
+ from typing import List, Optional, Tuple, Union
18
18
 
19
19
  import torch
20
20
 
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
77
77
  Video](https://imagen.research.google/video/paper.pdf) paper).
78
78
  rho (`float`, *optional*, defaults to 7.0):
79
79
  The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
80
+ final_sigmas_type (`str`, defaults to `"zero"`):
81
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
82
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
80
83
  """
81
84
 
82
85
  _compatibles = []
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
92
95
  num_train_timesteps: int = 1000,
93
96
  prediction_type: str = "epsilon",
94
97
  rho: float = 7.0,
98
+ final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
95
99
  ):
96
100
  if sigma_schedule not in ["karras", "exponential"]:
97
101
  raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
99
103
  # setable values
100
104
  self.num_inference_steps = None
101
105
 
102
- ramp = torch.linspace(0, 1, num_train_timesteps)
106
+ sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
103
107
  if sigma_schedule == "karras":
104
- sigmas = self._compute_karras_sigmas(ramp)
108
+ sigmas = self._compute_karras_sigmas(sigmas)
105
109
  elif sigma_schedule == "exponential":
106
- sigmas = self._compute_exponential_sigmas(ramp)
110
+ sigmas = self._compute_exponential_sigmas(sigmas)
107
111
 
108
112
  self.timesteps = self.precondition_noise(sigmas)
109
113
 
110
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
114
+ if self.config.final_sigmas_type == "sigma_min":
115
+ sigma_last = sigmas[-1]
116
+ elif self.config.final_sigmas_type == "zero":
117
+ sigma_last = 0
118
+ else:
119
+ raise ValueError(
120
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
121
+ )
122
+
123
+ self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
111
124
 
112
125
  self.is_scale_input_called = False
113
126
 
@@ -197,7 +210,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
197
210
  self.is_scale_input_called = True
198
211
  return sample
199
212
 
200
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
213
+ def set_timesteps(
214
+ self,
215
+ num_inference_steps: int = None,
216
+ device: Union[str, torch.device] = None,
217
+ sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
218
+ ):
201
219
  """
202
220
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
203
221
 
@@ -206,19 +224,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
206
224
  The number of diffusion steps used when generating samples with a pre-trained model.
207
225
  device (`str` or `torch.device`, *optional*):
208
226
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
227
+ sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
228
+ Custom sigmas to use for the denoising process. If not defined, the default behavior when
229
+ `num_inference_steps` is passed will be used.
209
230
  """
210
231
  self.num_inference_steps = num_inference_steps
211
232
 
212
- ramp = torch.linspace(0, 1, self.num_inference_steps)
233
+ if sigmas is None:
234
+ sigmas = torch.linspace(0, 1, self.num_inference_steps)
235
+ elif isinstance(sigmas, float):
236
+ sigmas = torch.tensor(sigmas, dtype=torch.float32)
237
+ else:
238
+ sigmas = sigmas
213
239
  if self.config.sigma_schedule == "karras":
214
- sigmas = self._compute_karras_sigmas(ramp)
240
+ sigmas = self._compute_karras_sigmas(sigmas)
215
241
  elif self.config.sigma_schedule == "exponential":
216
- sigmas = self._compute_exponential_sigmas(ramp)
242
+ sigmas = self._compute_exponential_sigmas(sigmas)
217
243
 
218
244
  sigmas = sigmas.to(dtype=torch.float32, device=device)
219
245
  self.timesteps = self.precondition_noise(sigmas)
220
246
 
221
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
247
+ if self.config.final_sigmas_type == "sigma_min":
248
+ sigma_last = sigmas[-1]
249
+ elif self.config.final_sigmas_type == "zero":
250
+ sigma_last = 0
251
+ else:
252
+ raise ValueError(
253
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
254
+ )
255
+
256
+ self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
222
257
  self._step_index = None
223
258
  self._begin_index = None
224
259
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication