diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -54,11 +54,32 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
54
54
  Args:
55
55
  num_train_timesteps (`int`, defaults to 1000):
56
56
  The number of diffusion steps to train the model.
57
- timestep_spacing (`str`, defaults to `"linspace"`):
58
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
59
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
60
57
  shift (`float`, defaults to 1.0):
61
58
  The shift value for the timestep schedule.
59
+ use_dynamic_shifting (`bool`, defaults to False):
60
+ Whether to apply timestep shifting on-the-fly based on the image resolution.
61
+ base_shift (`float`, defaults to 0.5):
62
+ Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
63
+ with desired output.
64
+ max_shift (`float`, defaults to 1.15):
65
+ Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
66
+ more exaggerated or stylized.
67
+ base_image_seq_len (`int`, defaults to 256):
68
+ The base image sequence length.
69
+ max_image_seq_len (`int`, defaults to 4096):
70
+ The maximum image sequence length.
71
+ invert_sigmas (`bool`, defaults to False):
72
+ Whether to invert the sigmas.
73
+ shift_terminal (`float`, defaults to None):
74
+ The end value of the shifted timestep schedule.
75
+ use_karras_sigmas (`bool`, defaults to False):
76
+ Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
77
+ use_exponential_sigmas (`bool`, defaults to False):
78
+ Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
79
+ use_beta_sigmas (`bool`, defaults to False):
80
+ Whether to use beta sigmas for step sizes in the noise schedule during sampling.
81
+ time_shift_type (`str`, defaults to "exponential"):
82
+ The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
62
83
  """
63
84
 
64
85
  _compatibles = []
@@ -69,7 +90,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
69
90
  self,
70
91
  num_train_timesteps: int = 1000,
71
92
  shift: float = 1.0,
72
- use_dynamic_shifting=False,
93
+ use_dynamic_shifting: bool = False,
73
94
  base_shift: Optional[float] = 0.5,
74
95
  max_shift: Optional[float] = 1.15,
75
96
  base_image_seq_len: Optional[int] = 256,
@@ -79,6 +100,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
79
100
  use_karras_sigmas: Optional[bool] = False,
80
101
  use_exponential_sigmas: Optional[bool] = False,
81
102
  use_beta_sigmas: Optional[bool] = False,
103
+ time_shift_type: str = "exponential",
82
104
  ):
83
105
  if self.config.use_beta_sigmas and not is_scipy_available():
84
106
  raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -86,6 +108,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
86
108
  raise ValueError(
87
109
  "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
88
110
  )
111
+ if time_shift_type not in {"exponential", "linear"}:
112
+ raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
113
+
89
114
  timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
90
115
  timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
91
116
 
@@ -192,7 +217,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
192
217
  return sigma * self.config.num_train_timesteps
193
218
 
194
219
  def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
195
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
220
+ if self.config.time_shift_type == "exponential":
221
+ return self._time_shift_exponential(mu, sigma, t)
222
+ elif self.config.time_shift_type == "linear":
223
+ return self._time_shift_linear(mu, sigma, t)
196
224
 
197
225
  def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
198
226
  r"""
@@ -217,54 +245,94 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
217
245
 
218
246
  def set_timesteps(
219
247
  self,
220
- num_inference_steps: int = None,
248
+ num_inference_steps: Optional[int] = None,
221
249
  device: Union[str, torch.device] = None,
222
250
  sigmas: Optional[List[float]] = None,
223
251
  mu: Optional[float] = None,
252
+ timesteps: Optional[List[float]] = None,
224
253
  ):
225
254
  """
226
255
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
227
256
 
228
257
  Args:
229
- num_inference_steps (`int`):
258
+ num_inference_steps (`int`, *optional*):
230
259
  The number of diffusion steps used when generating samples with a pre-trained model.
231
260
  device (`str` or `torch.device`, *optional*):
232
261
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
262
+ sigmas (`List[float]`, *optional*):
263
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
264
+ automatically.
265
+ mu (`float`, *optional*):
266
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
267
+ shifting.
268
+ timesteps (`List[float]`, *optional*):
269
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
270
+ automatically.
233
271
  """
234
272
  if self.config.use_dynamic_shifting and mu is None:
235
- raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
273
+ raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
274
+
275
+ if sigmas is not None and timesteps is not None:
276
+ if len(sigmas) != len(timesteps):
277
+ raise ValueError("`sigmas` and `timesteps` should have the same length")
278
+
279
+ if num_inference_steps is not None:
280
+ if (sigmas is not None and len(sigmas) != num_inference_steps) or (
281
+ timesteps is not None and len(timesteps) != num_inference_steps
282
+ ):
283
+ raise ValueError(
284
+ "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
285
+ )
286
+ else:
287
+ num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
236
288
 
237
- if sigmas is None:
238
- timesteps = np.linspace(
239
- self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
240
- )
289
+ self.num_inference_steps = num_inference_steps
290
+
291
+ # 1. Prepare default sigmas
292
+ is_timesteps_provided = timesteps is not None
241
293
 
294
+ if is_timesteps_provided:
295
+ timesteps = np.array(timesteps).astype(np.float32)
296
+
297
+ if sigmas is None:
298
+ if timesteps is None:
299
+ timesteps = np.linspace(
300
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
301
+ )
242
302
  sigmas = timesteps / self.config.num_train_timesteps
243
303
  else:
244
304
  sigmas = np.array(sigmas).astype(np.float32)
245
305
  num_inference_steps = len(sigmas)
246
- self.num_inference_steps = num_inference_steps
247
306
 
307
+ # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
308
+ # "exponential" or "linear" type is applied
248
309
  if self.config.use_dynamic_shifting:
249
310
  sigmas = self.time_shift(mu, 1.0, sigmas)
250
311
  else:
251
312
  sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
252
313
 
314
+ # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
253
315
  if self.config.shift_terminal:
254
316
  sigmas = self.stretch_shift_to_terminal(sigmas)
255
317
 
318
+ # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
256
319
  if self.config.use_karras_sigmas:
257
320
  sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
258
-
259
321
  elif self.config.use_exponential_sigmas:
260
322
  sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
261
-
262
323
  elif self.config.use_beta_sigmas:
263
324
  sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
264
325
 
326
+ # 5. Convert sigmas and timesteps to tensors and move to specified device
265
327
  sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
266
- timesteps = sigmas * self.config.num_train_timesteps
328
+ if not is_timesteps_provided:
329
+ timesteps = sigmas * self.config.num_train_timesteps
330
+ else:
331
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
267
332
 
333
+ # 6. Append the terminal sigma value.
334
+ # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
335
+ # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
268
336
  if self.config.invert_sigmas:
269
337
  sigmas = 1.0 - sigmas
270
338
  timesteps = sigmas * self.config.num_train_timesteps
@@ -272,7 +340,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
272
340
  else:
273
341
  sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
274
342
 
275
- self.timesteps = timesteps.to(device=device)
343
+ self.timesteps = timesteps
276
344
  self.sigmas = sigmas
277
345
  self._step_index = None
278
346
  self._begin_index = None
@@ -309,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
309
377
  s_tmax: float = float("inf"),
310
378
  s_noise: float = 1.0,
311
379
  generator: Optional[torch.Generator] = None,
380
+ per_token_timesteps: Optional[torch.Tensor] = None,
312
381
  return_dict: bool = True,
313
382
  ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
314
383
  """
@@ -329,14 +398,17 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
329
398
  Scaling factor for noise added to the sample.
330
399
  generator (`torch.Generator`, *optional*):
331
400
  A random number generator.
401
+ per_token_timesteps (`torch.Tensor`, *optional*):
402
+ The timesteps for each token in the sample.
332
403
  return_dict (`bool`):
333
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
334
- tuple.
404
+ Whether or not to return a
405
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
335
406
 
336
407
  Returns:
337
- [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
338
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
339
- returned, otherwise a tuple is returned where the first element is the sample tensor.
408
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
409
+ If return_dict is `True`,
410
+ [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
411
+ otherwise a tuple is returned where the first element is the sample tensor.
340
412
  """
341
413
 
342
414
  if (
@@ -347,7 +419,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
347
419
  raise ValueError(
348
420
  (
349
421
  "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
350
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
422
+ " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
351
423
  " one of the `scheduler.timesteps` as a timestep."
352
424
  ),
353
425
  )
@@ -358,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
358
430
  # Upcast to avoid precision issues when computing prev_sample
359
431
  sample = sample.to(torch.float32)
360
432
 
361
- sigma = self.sigmas[self.step_index]
362
- sigma_next = self.sigmas[self.step_index + 1]
433
+ if per_token_timesteps is not None:
434
+ per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
363
435
 
364
- prev_sample = sample + (sigma_next - sigma) * model_output
436
+ sigmas = self.sigmas[:, None, None]
437
+ lower_mask = sigmas < per_token_sigmas[None] - 1e-6
438
+ lower_sigmas = lower_mask * sigmas
439
+ lower_sigmas, _ = lower_sigmas.max(dim=0)
440
+ dt = (per_token_sigmas - lower_sigmas)[..., None]
441
+ else:
442
+ sigma = self.sigmas[self.step_index]
443
+ sigma_next = self.sigmas[self.step_index + 1]
444
+ dt = sigma_next - sigma
365
445
 
366
- # Cast sample back to model compatible dtype
367
- prev_sample = prev_sample.to(model_output.dtype)
446
+ prev_sample = sample + dt * model_output
368
447
 
369
448
  # upon completion increase step index by one
370
449
  self._step_index += 1
450
+ if per_token_timesteps is None:
451
+ # Cast sample back to model compatible dtype
452
+ prev_sample = prev_sample.to(model_output.dtype)
371
453
 
372
454
  if not return_dict:
373
455
  return (prev_sample,)
@@ -454,5 +536,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
454
536
  )
455
537
  return sigmas
456
538
 
539
+ def _time_shift_exponential(self, mu, sigma, t):
540
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
541
+
542
+ def _time_shift_linear(self, mu, sigma, t):
543
+ return mu / (mu + (1 / t - 1) ** sigma)
544
+
457
545
  def __len__(self):
458
546
  return self.config.num_train_timesteps
@@ -228,13 +228,14 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
228
228
  generator (`torch.Generator`, *optional*):
229
229
  A random number generator.
230
230
  return_dict (`bool`):
231
- Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
232
- tuple.
231
+ Whether or not to return a
232
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or tuple.
233
233
 
234
234
  Returns:
235
- [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
236
- If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
237
- returned, otherwise a tuple is returned where the first element is the sample tensor.
235
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
236
+ If return_dict is `True`,
237
+ [`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] is returned,
238
+ otherwise a tuple is returned where the first element is the sample tensor.
238
239
  """
239
240
 
240
241
  if (
@@ -245,7 +246,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
245
246
  raise ValueError(
246
247
  (
247
248
  "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
248
- " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
249
+ " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
249
250
  " one of the `scheduler.timesteps` as a timestep."
250
251
  ),
251
252
  )
@@ -342,7 +342,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
342
342
  timesteps = torch.from_numpy(timesteps)
343
343
  timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
344
344
 
345
- self.timesteps = timesteps.to(device=device)
345
+ self.timesteps = timesteps.to(device=device, dtype=torch.float32)
346
346
 
347
347
  # empty dt and derivative
348
348
  self.prev_derivative = None
@@ -413,8 +413,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
413
413
 
414
414
  if timesteps[0] >= self.config.num_train_timesteps:
415
415
  raise ValueError(
416
- f"`timesteps` must start before `self.config.train_timesteps`:"
417
- f" {self.config.num_train_timesteps}."
416
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
418
417
  )
419
418
 
420
419
  # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
@@ -311,7 +311,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
311
311
  sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
312
312
 
313
313
  self.sigmas = torch.from_numpy(sigmas).to(device=device)
314
- self.timesteps = torch.from_numpy(timesteps).to(device=device)
314
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
315
315
  self._step_index = None
316
316
  self._begin_index = None
317
317
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@@ -319,7 +319,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
319
319
  prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
320
320
 
321
321
  # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
322
- prev_known_part = (alpha_prod_t_prev**0.5) * original_image + (1 - alpha_prod_t_prev) * noise
322
+ # The computation reported in Algorithm 1 Line 5 is incorrect. Line 5 refers to formula (8a) of the same paper,
323
+ # which tells to sample from a Gaussian distribution with mean "(alpha_prod_t_prev**0.5) * original_image"
324
+ # and variance "(1 - alpha_prod_t_prev)". This means that the standard Gaussian distribution "noise" should be
325
+ # scaled by the square root of the variance (as it is done here), however Algorithm 1 Line 5 tells to scale by the variance.
326
+ prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
323
327
 
324
328
  # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
325
329
  pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
@@ -0,0 +1,265 @@
1
+ # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from ..configuration_utils import ConfigMixin, register_to_config
25
+ from ..schedulers.scheduling_utils import SchedulerMixin
26
+ from ..utils import BaseOutput, logging
27
+ from ..utils.torch_utils import randn_tensor
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM
class SCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
50
+
51
+
52
class SCMScheduler(SchedulerMixin, ConfigMixin):
    """
    `SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass
    documentation for the generic methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        prediction_type (`str`, defaults to `"trigflow"`):
            Prediction type of the scheduler function. Currently only supports "trigflow".
        sigma_data (`float`, defaults to 0.5):
            The standard deviation of the noise added during multi-step inference.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        prediction_type: str = "trigflow",
        sigma_data: float = 0.5,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values; the default training schedule counts down from num_train_timesteps - 1 to 0
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        # Index into `self.timesteps` for the current step; initialized lazily on the first `step` call.
        return self._step_index

    @property
    def begin_index(self):
        # The index of the first timestep; should be set from the pipeline via `set_begin_index`.
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int,
        timesteps: Optional[torch.Tensor] = None,
        device: Union[str, torch.device] = None,
        max_timesteps: float = 1.57080,
        intermediate_timesteps: float = 1.3,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            timesteps (`torch.Tensor` or `list`, *optional*):
                Custom timesteps to use for the denoising process; must be of length `num_inference_steps + 1`.
                NOTE(review): because `max_timesteps` has a non-None default, callers providing custom `timesteps`
                must also explicitly pass `max_timesteps=None`, or the validation below raises — confirm this is the
                intended API.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            max_timesteps (`float`, defaults to 1.57080):
                The maximum timestep value used in the SCM scheduler.
            intermediate_timesteps (`float`, *optional*, defaults to 1.3):
                The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
        """
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        if timesteps is not None and len(timesteps) != num_inference_steps + 1:
            raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")

        if timesteps is not None and max_timesteps is not None:
            raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")

        if timesteps is None and max_timesteps is None:
            raise ValueError("Should provide either `timesteps` or `max_timesteps`.")

        if intermediate_timesteps is not None and num_inference_steps != 2:
            raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")

        self.num_inference_steps = num_inference_steps

        if timesteps is not None:
            if isinstance(timesteps, list):
                self.timesteps = torch.tensor(timesteps, device=device).float()
            elif isinstance(timesteps, torch.Tensor):
                self.timesteps = timesteps.to(device).float()
            else:
                raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
        elif intermediate_timesteps is not None:
            # 3-point schedule [max, intermediate, 0]; only valid for num_inference_steps == 2 (validated above)
            self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float()
        else:
            # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
            self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
        # Use the library logger (controllable verbosity) instead of a bare `print`.
        logger.debug("Set timesteps: %s", self.timesteps)

        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator for the noise used in multi-step inference.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` was not called first, or `prediction_type` is unsupported.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Current (s) and next (t) points of the timestep schedule
        t = self.timesteps[self.step_index + 1]
        s = self.timesteps[self.step_index]

        # Different parameterizations of the model output
        parameterization = self.config.prediction_type

        if parameterization == "trigflow":
            pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output
        else:
            raise ValueError(f"Unsupported parameterization: {parameterization}")

        # Sample z scaled by sigma_data for multi-step inference.
        # Noise is not used for one-step sampling.
        if len(self.timesteps) > 1:
            noise = (
                randn_tensor(model_output.shape, device=model_output.device, generator=generator)
                * self.config.sigma_data
            )
            prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
        else:
            prev_sample = pred_x0

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, pred_x0)

        return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)

    def __len__(self):
        return self.config.num_train_timesteps
@@ -431,8 +431,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
431
431
 
432
432
  if timesteps[0] >= self.config.num_train_timesteps:
433
433
  raise ValueError(
434
- f"`timesteps` must start before `self.config.train_timesteps`:"
435
- f" {self.config.num_train_timesteps}."
434
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
436
435
  )
437
436
 
438
437
  # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
@@ -19,6 +19,7 @@ from typing import Optional, Union
19
19
 
20
20
  import torch
21
21
  from huggingface_hub.utils import validate_hf_hub_args
22
+ from typing_extensions import Self
22
23
 
23
24
  from ..utils import BaseOutput, PushToHubMixin
24
25
 
@@ -99,7 +100,7 @@ class SchedulerMixin(PushToHubMixin):
99
100
  subfolder: Optional[str] = None,
100
101
  return_unused_kwargs=False,
101
102
  **kwargs,
102
- ):
103
+ ) -> Self:
103
104
  r"""
104
105
  Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.
105
106