diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,653 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, List, Optional, Tuple
18
+
19
+ import torch
20
+
21
+ from ..models.attention_processor import Attention, MochiAttention
22
+ from ..models.modeling_outputs import Transformer2DModelOutput
23
+ from ..utils import logging
24
+ from .hooks import HookRegistry, ModelHook
25
+
26
+
27
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Registry keys under which the FasterCache hooks are registered on a module.
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
# Attention module classes that FasterCache knows how to hook into.
_ATTENTION_CLASSES = (Attention, MochiAttention)
# Regex patterns (matched with `re.search` against submodule names) identifying
# spatial and temporal attention blocks in supported transformer models.
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
# Kwarg names that may carry batchwise-concatenated [unconditional; conditional] inputs
# (see FasterCacheDenoiserHook for how these are split).
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
    "hidden_states",
    "encoder_hidden_states",
    "timestep",
    "attention_mask",
    "encoder_attention_mask",
)
47
+
48
+
49
@dataclass
class FasterCacheConfig:
    r"""
    Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).

    Attributes:
        spatial_attention_block_skip_range (`int`, defaults to `2`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
            be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
            states again.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
            be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
            states again.
        spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            The timestep range within which the spatial attention computation can be skipped without a significant loss
            in quality. This is to be determined by the user based on the underlying model. The first value in the
            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0). For the default values, this would mean that the spatial attention computation skipping will
            be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
            process.
        temporal_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            The timestep range within which the temporal attention computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
            timestep 0).
        low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
            The timestep range within which the low frequency weight scaling update is applied. The first value in the
            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
            The timestep range within which the high frequency weight scaling update is applied. The first value in the
            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
            function for the update is called only within this range.
        alpha_low_frequency (`float`, defaults to `1.1`):
            The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
            the conditional branch outputs.
        alpha_high_frequency (`float`, defaults to `1.1`):
            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
            from the conditional branch outputs.
        unconditional_batch_skip_range (`int`, defaults to `5`):
            Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
            computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
            computing the new unconditional branch states again.
        unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
            The timestep range within which the unconditional branch computation can be skipped without a significant
            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
            tuple is the lower bound and the second value is the upper bound.
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")`):
            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
            partial layer names, or regex patterns. Matching will always be done using a regex match.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^temporal_transformer_blocks.*attn",)`):
            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
            partial layer names, or regex patterns. Matching will always be done using a regex match.
        attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the attention outputs by. This function should take
            the attention module as input and return a float value. This is used to approximate the unconditional
            branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
            Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
            progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
            the number of inference steps and underlying model behaviour as denoising progresses.
        low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the low frequency updates by. If not provided, the
            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
            The callback function to determine the weight to scale the high frequency updates by. If not provided, the
            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        tensor_format (`str`, defaults to `"BCFHW"`):
            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
            used to split individual latent frames in order for low and high frequency components to be computed.
        is_guidance_distilled (`bool`, defaults to `False`):
            Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
            applied at the denoiser-level to skip the unconditional branch computation (as there is none).
        current_timestep_callback (`Callable[[], int]`, defaults to `None`):
            The callback function to retrieve the current denoising timestep. This must be provided (typically by the
            pipeline) for the timestep-range checks above to work.
        _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
            conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
            split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
            names that contain the batchwise-concatenated unconditional and conditional inputs.
    """

    # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
    # after some testing. We default to 2 if these parameters are not provided.
    spatial_attention_block_skip_range: int = 2
    temporal_attention_block_skip_range: Optional[int] = None

    spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
    temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)

    # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
    low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
    high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)

    # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
    alpha_low_frequency: float = 1.1
    alpha_high_frequency: float = 1.1

    # n as described in CFG-Cache explanation in the paper - dependent on the model
    unconditional_batch_skip_range: int = 5
    unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)

    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS

    attention_weight_callback: Optional[Callable[[torch.nn.Module], float]] = None
    low_frequency_weight_callback: Optional[Callable[[torch.nn.Module], float]] = None
    high_frequency_weight_callback: Optional[Callable[[torch.nn.Module], float]] = None

    tensor_format: str = "BCFHW"
    is_guidance_distilled: bool = False

    current_timestep_callback: Optional[Callable[[], int]] = None

    _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS

    def __repr__(self) -> str:
        return (
            f"FasterCacheConfig(\n"
            f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f"  spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f"  temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f"  low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
            f"  high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
            f"  alpha_low_frequency={self.alpha_low_frequency},\n"
            f"  alpha_high_frequency={self.alpha_high_frequency},\n"
            f"  unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
            f"  unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
            f"  spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f"  temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f"  tensor_format={self.tensor_format},\n"
            f")"
        )
185
+
186
+
187
class FasterCacheDenoiserState:
    r"""
    Mutable per-denoiser state for [FasterCache](https://huggingface.co/papers/2410.19355): an iteration
    counter plus the cached low/high frequency deltas between unconditional and conditional branch outputs.
    """

    def __init__(self) -> None:
        # Construction is equivalent to a reset; share one code path for both.
        self.reset()

    def reset(self):
        # Number of denoiser forward passes seen so far.
        self.iteration = 0
        # Frequency-domain deltas (uncond - cond); None until the first full forward pass populates them.
        self.low_frequency_delta = None
        self.high_frequency_delta = None
201
+
202
+
203
class FasterCacheBlockState:
    r"""
    Mutable per-block state for [FasterCache](https://huggingface.co/papers/2410.19355). One instance is
    attached to every underlying block that FasterCache is applied to.
    """

    def __init__(self) -> None:
        # Construction is equivalent to a reset; share one code path for both.
        self.reset()

    def reset(self):
        # Number of block forward passes seen so far.
        self.iteration = 0
        # True (uncond + cond) batch size, recorded on the first forward pass through the denoiser.
        self.batch_size = None
        # Outputs of the two most recent computed passes, ordered (t-2 output, t-1 output).
        self.cache = None
218
+
219
+
220
class FasterCacheDenoiserHook(ModelHook):
    r"""
    Denoiser-level hook implementing the CFG-Cache part of
    [FasterCache](https://huggingface.co/papers/2410.19355): on eligible iterations the unconditional branch
    is not computed but approximated in frequency space from the conditional branch output and the cached
    low/high frequency deltas stored in `FasterCacheDenoiserState`.
    """

    # Stateful hook: the hook registry is expected to call `reset_state` between sampling runs.
    _is_stateful = True

    def __init__(
        self,
        unconditional_batch_skip_range: int,
        unconditional_batch_timestep_skip_range: Tuple[int, int],
        tensor_format: str,
        is_guidance_distilled: bool,
        uncond_cond_input_kwargs_identifiers: List[str],
        current_timestep_callback: Callable[[], int],
        low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
        high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
    ) -> None:
        super().__init__()

        self.unconditional_batch_skip_range = unconditional_batch_skip_range
        self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
        # We can't easily detect what args are to be split in unconditional and conditional branches. We
        # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
        # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
        # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
        self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
        self.tensor_format = tensor_format
        self.is_guidance_distilled = is_guidance_distilled

        self.current_timestep_callback = current_timestep_callback
        self.low_frequency_weight_callback = low_frequency_weight_callback
        self.high_frequency_weight_callback = high_frequency_weight_callback

    def initialize_hook(self, module):
        # Fresh state for the hooked module; cleared again via `reset_state` between runs.
        self.state = FasterCacheDenoiserState()
        return module

    @staticmethod
    def _get_cond_input(input: torch.Tensor) -> torch.Tensor:
        # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
        # followed by conditional inputs. Only the conditional half is returned.
        _, cond = input.chunk(2, dim=0)
        return cond

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
        # requirements for skipping the unconditional branch are met as described in the paper.
        # We skip the unconditional branch only if the following conditions are met:
        # 1. We have completed at least one iteration of the denoiser
        # 2. The current timestep is within the range specified by the user. This is the optimal timestep range
        #    where approximating the unconditional branch from the computation of the conditional branch is possible
        #    without a significant loss in quality.
        # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
        #    we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
        is_within_timestep_range = (
            self.unconditional_batch_timestep_skip_range[0]
            < self.current_timestep_callback()
            < self.unconditional_batch_timestep_skip_range[1]
        )
        should_skip_uncond = (
            self.state.iteration > 0
            and is_within_timestep_range
            and self.state.iteration % self.unconditional_batch_skip_range != 0
            and not self.is_guidance_distilled
        )

        if should_skip_uncond:
            is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
            if is_any_kwarg_uncond:
                # Drop the unconditional half of every tensor input so only the conditional branch is inferred.
                logger.debug("FasterCache - Skipping unconditional branch computation")
                args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
                kwargs = {
                    k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
                    for k, v in kwargs.items()
                }

        output = self.fn_ref.original_forward(*args, **kwargs)

        if self.is_guidance_distilled:
            # Guidance-distilled models have no unconditional branch; nothing to approximate.
            self.state.iteration += 1
            return output

        # NOTE(review): if `output` is neither a tensor nor a tuple/Transformer2DModelOutput,
        # `hidden_states` would be unbound below — confirm all hooked denoisers return one of these.
        if torch.is_tensor(output):
            hidden_states = output
        elif isinstance(output, (tuple, Transformer2DModelOutput)):
            hidden_states = output[0]

        batch_size = hidden_states.size(0)

        if should_skip_uncond:
            # Scale the cached deltas per the configured weight callbacks (Equation 11 of the paper).
            self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
                module
            )
            self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
                module
            )

            # Normalize video tensors to a (batch * frames, C, H, W) layout before the 2D FFT split.
            if self.tensor_format == "BCFHW":
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                hidden_states = hidden_states.flatten(0, 1)

            low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())

            # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
            low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
            high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
            uncond_freq = low_freq_uncond + high_freq_uncond

            uncond_states = torch.fft.ifftshift(uncond_freq)
            uncond_states = torch.fft.ifft2(uncond_states).real

            # Undo the layout normalization applied above.
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.unflatten(0, (batch_size, -1))
                hidden_states = hidden_states.unflatten(0, (batch_size, -1))
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

            # Concatenate the approximated unconditional and predicted conditional branches
            uncond_states = uncond_states.to(hidden_states.dtype)
            hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
        else:
            # Both branches were computed: refresh the cached frequency deltas (uncond - cond).
            uncond_states, cond_states = hidden_states.chunk(2, dim=0)
            if self.tensor_format == "BCFHW":
                uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
                cond_states = cond_states.permute(0, 2, 1, 3, 4)
            if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
                uncond_states = uncond_states.flatten(0, 1)
                cond_states = cond_states.flatten(0, 1)

            low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
            low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
            self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
            self.state.high_frequency_delta = high_freq_uncond - high_freq_cond

        self.state.iteration += 1
        # Re-wrap the result in the same container type the original forward returned.
        if torch.is_tensor(output):
            output = hidden_states
        elif isinstance(output, tuple):
            output = (hidden_states, *output[1:])
        else:
            output.sample = hidden_states

        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module
366
+
367
+
368
class FasterCacheBlockHook(ModelHook):
    r"""
    Block-level hook for [FasterCache](https://huggingface.co/papers/2410.19355): on eligible iterations the
    attention computation is skipped and its output approximated by linear extrapolation from the two most
    recently cached outputs.
    """

    # Stateful hook: the hook registry is expected to call `reset_state` between sampling runs.
    _is_stateful = True

    def __init__(
        self,
        block_skip_range: int,
        timestep_skip_range: Tuple[int, int],
        is_guidance_distilled: bool,
        weight_callback: Callable[[torch.nn.Module], float],
        current_timestep_callback: Callable[[], int],
    ) -> None:
        super().__init__()

        self.block_skip_range = block_skip_range
        self.timestep_skip_range = timestep_skip_range
        self.is_guidance_distilled = is_guidance_distilled

        self.weight_callback = weight_callback
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        # Fresh state for the hooked block; cleared again via `reset_state` between runs.
        self.state = FasterCacheBlockState()
        return module

    def _compute_approximated_attention_output(
        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
    ) -> torch.Tensor:
        # Linear extrapolation from the two cached (conditional-branch) outputs: o(t) + (o(t) - o(t-2)) * w.
        if t_2_output.size(0) != batch_size:
            # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
            # take the conditional branch outputs.
            assert t_2_output.size(0) == 2 * batch_size
            t_2_output = t_2_output[batch_size:]
        if t_output.size(0) != batch_size:
            # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
            # take the conditional branch outputs.
            assert t_output.size(0) == 2 * batch_size
            t_output = t_output[batch_size:]
        return t_output + (t_output - t_2_output) * weight

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        # Infer the current batch size from the first tensor argument (positional or keyword).
        batch_size = [
            *[arg.size(0) for arg in args if torch.is_tensor(arg)],
            *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
        ][0]
        if self.state.batch_size is None:
            # Will be updated on first forward pass through the denoiser
            self.state.batch_size = batch_size

        # If we have to skip due to the skip conditions, then let's skip as expected.
        # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
        # is because the expected output shapes of attention layer will not match if we only return values from
        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
        # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
        # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )
        if not is_within_timestep_range:
            should_skip_attention = False
        else:
            should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
            should_skip_attention = not should_compute_attention
            if should_skip_attention:
                should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size

        if should_skip_attention:
            logger.debug("FasterCache - Skipping attention and using approximation")
            if torch.is_tensor(self.state.cache[-1]):
                # Single-tensor output: extrapolate directly from the cached pair.
                t_2_output, t_output = self.state.cache
                weight = self.weight_callback(module)
                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
            else:
                # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
                # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
                # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
                # a forward pass of the block. We need to compute the approximated output for each of these tensors.
                # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
                # allows us to compute the approximated attention output for each tensor in the cache.
                output = ()
                for t_2_output, t_output in zip(*self.state.cache):
                    result = self._compute_approximated_attention_output(
                        t_2_output, t_output, self.weight_callback(module), batch_size
                    )
                    output += (result,)
        else:
            logger.debug("FasterCache - Computing attention")
            output = self.fn_ref.original_forward(*args, **kwargs)

        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
        # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
        # both cases.
        if torch.is_tensor(output):
            cache_output = output
            if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
                # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
                # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
                cache_output = cache_output.chunk(2, dim=0)[1]
        else:
            # Cache all return values and perform the same operation as above
            cache_output = ()
            for out in output:
                if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
                    out = out.chunk(2, dim=0)[1]
                cache_output += (out,)

        # Keep a rolling window of the two most recent computed/approximated outputs: [t-2, t-1].
        if self.state.cache is None:
            self.state.cache = [cache_output, cache_output]
        else:
            self.state.cache = [self.state.cache[-1], cache_output]

        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.state.reset()
        return module
484
+
485
+
486
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    r"""
    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given module.

    Args:
        module (`torch.nn.Module`):
            The denoiser module (e.g. a pipeline's transformer) to apply FasterCache to.
        config (`FasterCacheConfig`):
            The configuration to use for FasterCache. `config.current_timestep_callback` must be set for the
            timestep-range checks to work.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = FasterCacheConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(-1, 681),
    ...     low_frequency_weight_update_timestep_range=(99, 641),
    ...     high_frequency_weight_update_timestep_range=(-1, 301),
    ...     spatial_attention_block_identifiers=["transformer_blocks"],
    ...     attention_weight_callback=lambda _: 0.3,
    ...     tensor_format="BFCHW",
    ... )
    >>> apply_faster_cache(pipe.transformer, config)
    ```
    """

    logger.warning(
        "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
        "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
        "https://github.com/huggingface/diffusers/issues."
    )

    if config.attention_weight_callback is None:
        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
        # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
        # this depends from model-to-model. It is required by the user to provide a weight callback if they want to
        # use a different weight function. Defaulting to 0.5 works well in practice for most cases.
        logger.warning(
            "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
        )
        config.attention_weight_callback = lambda _: 0.5

    if config.low_frequency_weight_callback is None:
        logger.debug(
            "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def low_frequency_weight_callback(module: torch.nn.Module) -> float:
            # Scale by alpha only inside the configured timestep window; outside it, leave the delta unchanged.
            is_within_range = (
                config.low_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.low_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_low_frequency if is_within_range else 1.0

        config.low_frequency_weight_callback = low_frequency_weight_callback

    if config.high_frequency_weight_callback is None:
        logger.debug(
            "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
        )

        def high_frequency_weight_callback(module: torch.nn.Module) -> float:
            # Scale by alpha only inside the configured timestep window; outside it, leave the delta unchanged.
            is_within_range = (
                config.high_frequency_weight_update_timestep_range[0]
                < config.current_timestep_callback()
                < config.high_frequency_weight_update_timestep_range[1]
            )
            return config.alpha_high_frequency if is_within_range else 1.0

        config.high_frequency_weight_callback = high_frequency_weight_callback

    supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"]  # TODO(aryan): Support BSC for LTX Video
    if config.tensor_format not in supported_tensor_formats:
        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")

    _apply_faster_cache_on_denoiser(module, config)

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
            continue
        # NOTE(review): candidate layers are pre-filtered with the module-level identifier patterns, not the
        # (possibly customized) identifiers from `config`; the config identifiers are only applied afterwards
        # inside `_apply_faster_cache_on_attention_class` — confirm this is the intended behaviour.
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)
574
+
575
+
576
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
    """Build the denoiser-level FasterCache hook from `config` and register it on `module`."""
    denoiser_hook = FasterCacheDenoiserHook(
        unconditional_batch_skip_range=config.unconditional_batch_skip_range,
        unconditional_batch_timestep_skip_range=config.unconditional_batch_timestep_skip_range,
        tensor_format=config.tensor_format,
        is_guidance_distilled=config.is_guidance_distilled,
        uncond_cond_input_kwargs_identifiers=config._unconditional_conditional_input_kwargs_identifiers,
        current_timestep_callback=config.current_timestep_callback,
        low_frequency_weight_callback=config.low_frequency_weight_callback,
        high_frequency_weight_callback=config.high_frequency_weight_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(denoiser_hook, _FASTER_CACHE_DENOISER_HOOK)
589
+
590
+
591
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
    r"""
    Registers a `FasterCacheBlockHook` on `module` if its name matches the configured spatial or temporal
    attention identifiers and the module is a self-attention layer.

    Args:
        name (`str`): Fully-qualified name of the submodule within the denoiser.
        module (`Attention`): The attention module to hook.
        config (`FasterCacheConfig`): The FasterCache configuration.
    """
    # A layer qualifies as spatial/temporal self-attention when (1) its name matches the configured regex
    # identifiers, (2) a skip range is configured for that attention type, and (3) it is not cross-attention.
    # `getattr` is used defensively in both checks (previously only the spatial check used it) because not
    # every attention implementation exposes `is_cross_attention`.
    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )

    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"

    if block_skip_range is None or timestep_skip_range is None:
        # Not an error: the layer simply did not match any FasterCache criteria.
        logger.debug(
            f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial or temporal attention layers. Note, "
            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
            f"block identifiers in the configuration if this layer was expected to be hooked."
        )
        return

    logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
    hook = FasterCacheBlockHook(
        block_skip_range,
        timestep_skip_range,
        config.is_guidance_distilled,
        config.attention_weight_callback,
        config.current_timestep_callback,
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
633
+
634
+
635
+ # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
636
+ @torch.no_grad()
637
+ def _split_low_high_freq(x):
638
+ fft = torch.fft.fft2(x)
639
+ fft_shifted = torch.fft.fftshift(fft)
640
+ height, width = x.shape[-2:]
641
+ radius = min(height, width) // 5
642
+
643
+ y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
644
+ center_x, center_y = width // 2, height // 2
645
+ mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
646
+
647
+ low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
648
+ high_freq_mask = ~low_freq_mask
649
+
650
+ low_freq_fft = fft_shifted * low_freq_mask
651
+ high_freq_fft = fft_shifted * high_freq_mask
652
+
653
+ return low_freq_fft, high_freq_fft