diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551):
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,56 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+
19
+ from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
20
+ from ..models.attention_processor import Attention, MochiAttention
21
+
22
+
23
# Attention module classes shared by the hook implementations in this package
# (for example, the FasterCache hook imports this tuple from here).
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
# Feed-forward module classes.
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
25
+
26
+ _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
27
+ _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
28
+ _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
29
+
30
+ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
31
+ {
32
+ *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
33
+ *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
34
+ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
35
+ }
36
+ )
37
+
38
# Layers supported for group offloading and layerwise casting
# ("GO" = group offloading, "LC" = layerwise casting).
_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Linear,
    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
    # because of double invocation of the same norm layer in CogVideoXLayerNorm
)
50
+
51
+
52
+ def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
53
+ for submodule_name, submodule in module.named_modules():
54
+ if submodule_name == fqn:
55
+ return submodule
56
+ return None
@@ -0,0 +1,293 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable, Dict, Type
18
+
19
+
20
@dataclass
class AttentionProcessorMetadata:
    """Metadata stored in `AttentionProcessorRegistry` for one attention processor class."""

    # Callable associated with the processor class. NOTE(review): the functions registered in this file hand
    # back the `hidden_states` (and for some processors `encoder_hidden_states`) arguments unchanged, so this
    # is presumably what stands in for the processor's output when attention is skipped — confirm with callers.
    skip_processor_output_fn: Callable[[Any], Any]
23
+
24
+
25
@dataclass
class TransformerBlockMetadata:
    """Describes how to read inputs and outputs of one transformer block class.

    The two `return_*_index` fields give the position of each tensor in the tuple returned by the block's
    `forward`; `return_encoder_hidden_states_index` is `None` for blocks that do not return encoder states.
    """

    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    # Block class this metadata belongs to; filled in by `TransformerBlockRegistry.register`.
    _cls: Type = None
    # Lazily built map from `forward` parameter name to positional index (excluding `self`).
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        """Fetch the value passed for the `forward` parameter named `identifier`.

        Keyword arguments take precedence. Otherwise the parameter's position is looked up in the signature of
        `self._cls.forward` (computed once and cached) and the value is read from `args`.

        Args:
            identifier: Name of the `forward` parameter to retrieve.
            args: Positional arguments the block was called with (without `self`).
            kwargs: Keyword arguments the block was called with.

        Returns:
            The argument value supplied for `identifier`.

        Raises:
            ValueError: If the block class is not set, if `identifier` is not a parameter of `forward`, or if
                too few positional arguments were supplied to contain it.
        """
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]

        if self._cached_parameter_indices is None:
            if self._cls is None:
                raise ValueError("Model class is not set for metadata.")
            parameters = list(inspect.signature(self._cls.forward).parameters.keys())
            parameters = parameters[1:]  # skip `self`
            self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        # Validate on every call, not only on the one that builds the cache, so that later lookups fail with
        # the same ValueError instead of a bare KeyError/IndexError.
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(
                f"Expected at least {index + 1} positional arguments to read '{identifier}' but got {len(args)}."
            )
        return args[index]
50
+
51
+
52
class AttentionProcessorRegistry:
    """Class-level table mapping attention processor classes to their `AttentionProcessorMetadata`."""

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        """Store `metadata` for `model_class`, after making sure the built-in entries are registered."""
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        """Return the metadata stored for `model_class`; raise `ValueError` if it was never registered."""
        cls._register()
        try:
            return cls._registry[model_class]
        except KeyError:
            raise ValueError(f"Model class {model_class} not registered.") from None

    @classmethod
    def _register(cls):
        """Run the built-in registrations once."""
        if not cls._is_registered:
            # Flip the flag before registering: the registration function calls `register`, which re-enters
            # this method.
            cls._is_registered = True
            _register_attention_processors_metadata()
77
+
78
+
79
class TransformerBlockRegistry:
    """Class-level table mapping transformer block classes to their `TransformerBlockMetadata`."""

    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        """Bind `metadata` to `model_class` (also recording the class on the metadata) and store it."""
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        """Return the metadata stored for `model_class`; raise `ValueError` if it was never registered."""
        cls._register()
        try:
            return cls._registry[model_class]
        except KeyError:
            raise ValueError(f"Model class {model_class} not registered.") from None

    @classmethod
    def _register(cls):
        """Run the built-in registrations once."""
        if not cls._is_registered:
            # Flip the flag before registering: the registration function calls `register`, which re-enters
            # this method.
            cls._is_registered = True
            _register_transformer_blocks_metadata()
105
+
106
+
107
def _register_attention_processors_metadata():
    """Register the skip-output function of every built-in attention processor class.

    Model imports are done inside the function so that importing this module does not import the models.
    """
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0

    # (processor class, function stored as `skip_processor_output_fn`), in registration order.
    processor_skip_fns = (
        (AttnProcessor2_0, _skip_proc_output_fn_Attention_AttnProcessor2_0),
        (CogView4AttnProcessor, _skip_proc_output_fn_Attention_CogView4AttnProcessor),
        (WanAttnProcessor2_0, _skip_proc_output_fn_Attention_WanAttnProcessor2_0),
        (FluxAttnProcessor, _skip_proc_output_fn_Attention_FluxAttnProcessor),
    )
    for processor_cls, skip_fn in processor_skip_fns:
        AttentionProcessorRegistry.register(
            model_class=processor_cls,
            metadata=AttentionProcessorMetadata(skip_processor_output_fn=skip_fn),
        )
142
+
143
+
144
def _register_transformer_blocks_metadata():
    """Register the return-value layout of every built-in transformer block class.

    Model imports are done inside the function so that importing this module does not import the models.
    """
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # (block class, return_hidden_states_index, return_encoder_hidden_states_index), in registration order.
    block_return_layouts = (
        # BasicTransformerBlock
        (BasicTransformerBlock, 0, None),
        # CogVideoX
        (CogVideoXBlock, 0, 1),
        # CogView4
        (CogView4TransformerBlock, 0, 1),
        # Flux
        (FluxTransformerBlock, 1, 0),
        (FluxSingleTransformerBlock, 1, 0),
        # HunyuanVideo
        (HunyuanVideoTransformerBlock, 0, 1),
        (HunyuanVideoSingleTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceTransformerBlock, 0, 1),
        (HunyuanVideoTokenReplaceSingleTransformerBlock, 0, 1),
        # LTXVideo
        (LTXVideoTransformerBlock, 0, None),
        # Mochi
        (MochiTransformerBlock, 0, 1),
        # Wan
        (WanTransformerBlock, 0, None),
        # QwenImage
        (QwenImageTransformerBlock, 1, 0),
    )
    for block_cls, hidden_states_index, encoder_hidden_states_index in block_return_layouts:
        # A fresh metadata object per class: `register` records the class on the metadata it is given.
        TransformerBlockRegistry.register(
            model_class=block_cls,
            metadata=TransformerBlockMetadata(
                return_hidden_states_index=hidden_states_index,
                return_encoder_hidden_states_index=encoder_hidden_states_index,
            ),
        )
268
+
269
+
270
+ # fmt: off
271
+ def _skip_attention___ret___hidden_states(self, *args, **kwargs):
272
+ hidden_states = kwargs.get("hidden_states", None)
273
+ if hidden_states is None and len(args) > 0:
274
+ hidden_states = args[0]
275
+ return hidden_states
276
+
277
+
278
+ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
279
+ hidden_states = kwargs.get("hidden_states", None)
280
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
281
+ if hidden_states is None and len(args) > 0:
282
+ hidden_states = args[0]
283
+ if encoder_hidden_states is None and len(args) > 1:
284
+ encoder_hidden_states = args[1]
285
+ return hidden_states, encoder_hidden_states
286
+
287
+
288
# Per-processor aliases for the skip-output functions; each alias selects which of the processor's input
# arguments are handed back.
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# NOTE(review): the original author was unsure which values FluxAttnProcessor should hand back here; returning
# only `hidden_states` is provisional — confirm against what the processor actually returns.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
293
+ # fmt: on
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,9 +18,10 @@ from typing import Any, Callable, List, Optional, Tuple
18
18
 
19
19
  import torch
20
20
 
21
- from ..models.attention_processor import Attention, MochiAttention
21
+ from ..models.attention import AttentionModuleMixin
22
22
  from ..models.modeling_outputs import Transformer2DModelOutput
23
23
  from ..utils import logging
24
+ from ._common import _ATTENTION_CLASSES
24
25
  from .hooks import HookRegistry, ModelHook
25
26
 
26
27
 
@@ -29,7 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
30
 
30
31
  _FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
31
32
  _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
32
- _ATTENTION_CLASSES = (Attention, MochiAttention)
33
33
  _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
34
34
  "^blocks.*attn",
35
35
  "^transformer_blocks.*attn",
@@ -146,7 +146,7 @@ class FasterCacheConfig:
146
146
  alpha_low_frequency: float = 1.1
147
147
  alpha_high_frequency: float = 1.1
148
148
 
149
- # n as described in CFG-Cache explanation in the paper - dependant on the model
149
+ # n as described in CFG-Cache explanation in the paper - dependent on the model
150
150
  unconditional_batch_skip_range: int = 5
151
151
  unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
152
152
 
@@ -488,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
488
488
  Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
489
489
 
490
490
  Args:
491
- pipeline (`DiffusionPipeline`):
492
- The diffusion pipeline to apply FasterCache to.
493
- config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
491
+ module (`torch.nn.Module`):
492
+ The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
493
+ in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
494
+ config (`FasterCacheConfig`):
494
495
  The configuration to use for FasterCache.
495
496
 
496
497
  Example:
@@ -588,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
588
589
  registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
589
590
 
590
591
 
591
- def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
592
+ def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
592
593
  is_spatial_self_attention = (
593
594
  any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
594
595
  and config.spatial_attention_block_skip_range is not None
@@ -0,0 +1,259 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Tuple, Union
17
+
18
+ import torch
19
+
20
+ from ..utils import get_logger
21
+ from ..utils.torch_utils import unwrap_module
22
+ from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
23
+ from ._helpers import TransformerBlockRegistry
24
+ from .hooks import BaseState, HookRegistry, ModelHook, StateManager
25
+
26
+
27
+ logger = get_logger(__name__) # pylint: disable=invalid-name
28
+
29
# Names under which the First Block Cache hooks are registered in a block's `HookRegistry`.
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"  # hook placed on the first (head) block
_FBC_BLOCK_HOOK = "fbc_block_hook"  # hook placed on every other block, including the tail block
31
+
32
+
33
@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold to determine whether or not a forward pass through all layers of the model is required. A
            higher threshold usually results in a forward pass through a lower number of layers and faster inference,
            but might lead to poorer generation quality. A lower threshold may not result in significant generation
            speedup. The threshold is compared against the absmean difference between the current and cached
            residuals of the first transformer block, taken relative to the absmean of the cached residual. If that
            relative difference is not greater than the threshold, the forward pass through the remaining blocks is
            skipped.
    """

    threshold: float = 0.05
50
+
51
+
52
class FBCSharedBlockState(BaseState):
    """State shared between the head-block hook and the other block hooks of First Block Cache."""

    def __init__(self) -> None:
        super().__init__()

        # Head block output saved on the last fully computed pass (written by the head hook, read by the tail hook).
        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        # Residual (head block output minus its input hidden states) from the last fully computed pass.
        self.head_block_residual: torch.Tensor = None
        # (hidden_states residual, encoder_hidden_states residual) between tail and head block outputs.
        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
        # Whether the blocks after the head block must run on the current pass.
        self.should_compute: bool = True

    def reset(self):
        """Return the state to its freshly constructed values.

        The head-block fields are cleared together with the tail residuals: leaving a stale
        `head_block_residual` behind while `tail_block_residuals` is `None` would let the head hook decide to
        reuse a cache that no longer exists.
        """
        self.head_block_output = None
        self.head_block_residual = None
        self.tail_block_residuals = None
        self.should_compute = True
64
+
65
+
66
class FBCHeadBlockHook(ModelHook):
    """Hook for the first transformer block: decides whether the remaining blocks run or the cache is reused.

    On each forward it runs the block, measures how much the block's residual changed since the last fully
    computed pass, and either stores fresh head-block state (compute path) or adds the cached tail residuals
    to the block's output (cached path).
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        super().__init__()
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        """Look up the registered metadata for the (unwrapped) block class."""
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the head block and return either its output or the cache-corrected output.

        Raises:
            ValueError: On the cached path, if the metadata declares an encoder hidden states output but the
                block returned a single value instead of a tuple.
        """
        hidden_states_index = self._metadata.return_hidden_states_index
        encoder_hidden_states_index = self._metadata.return_encoder_hidden_states_index
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        head_hidden_states = output[hidden_states_index] if is_output_tuple else output
        hidden_states_residual = head_hidden_states - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
        shared_state.should_compute = should_compute

        if should_compute:
            # Remember this pass so the tail hook can compute residuals against it and the next pass can
            # compare its head residual with this one.
            if is_output_tuple:
                # Blocks without an encoder output have no index to read; indexing with `None` would raise.
                head_encoder_hidden_states = (
                    output[encoder_hidden_states_index] if encoder_hidden_states_index is not None else None
                )
                head_block_output = (head_hidden_states, head_encoder_hidden_states)
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual
            return output

        # Apply caching: the remaining blocks are skipped, so add their cached contribution here.
        hidden_states = shared_state.tail_block_residuals[0] + head_hidden_states

        if not is_output_tuple:
            if encoder_hidden_states_index is not None:
                raise ValueError(
                    "Block metadata declares an encoder hidden states output, but the block returned a single "
                    "value instead of a tuple."
                )
            return hidden_states

        return_output = [None] * len(output)
        return_output[hidden_states_index] = hidden_states
        if encoder_hidden_states_index is not None:
            return_output[encoder_hidden_states_index] = (
                shared_state.tail_block_residuals[1] + output[encoder_hidden_states_index]
            )
        return tuple(return_output)

    def reset_state(self, module):
        """Reset the shared state manager."""
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
        """Return `True` when the blocks after the head block must run.

        They must run when there is nothing cached yet, or when the relative absmean change of the head
        residual exceeds `self.threshold`.
        """
        shared_state = self.state_manager.get_state()
        # Without both a previous head residual and cached tail residuals there is nothing to reuse.
        if shared_state.head_block_residual is None or shared_state.tail_block_residuals is None:
            return True
        prev_hidden_states_residual = shared_state.head_block_residual
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold
144
+
145
+
146
class FBCBlockHook(ModelHook):
    """Hook for every block after the head block.

    When the shared state says to compute, the block runs normally (and the tail block additionally records
    its residuals against the head block output). Otherwise the block is skipped and its inputs are passed
    through in the block's output layout.
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        """Look up the registered metadata for the (unwrapped) block class."""
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Run the block, or pass its inputs through when the head hook decided to reuse the cache."""
        hidden_states_index = self._metadata.return_hidden_states_index
        encoder_hidden_states_index = self._metadata.return_encoder_hidden_states_index

        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = None
        if encoder_hidden_states_index is not None:
            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )

        shared_state = self.state_manager.get_state()

        if shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                hidden_states_residual = encoder_hidden_states_residual = None
                if isinstance(output, tuple):
                    hidden_states_residual = output[hidden_states_index] - shared_state.head_block_output[0]
                    # Only blocks that declare an encoder output have an index to read; indexing a tuple
                    # with `None` would raise a TypeError.
                    if encoder_hidden_states_index is not None:
                        encoder_hidden_states_residual = (
                            output[encoder_hidden_states_index] - shared_state.head_block_output[1]
                        )
                else:
                    hidden_states_residual = output - shared_state.head_block_output
                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
            return output

        # Skipped: hand the inputs back unchanged, arranged like the block's own return value.
        if original_encoder_hidden_states is None:
            return original_hidden_states

        return_output = [None, None]
        return_output[hidden_states_index] = original_hidden_states
        return_output[encoder_hidden_states_index] = original_encoder_hidden_states
        return tuple(return_output)
192
+
193
+
194
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
    to implement generically for a wide range of models and has been integrated first for experimental purposes.

    Args:
        module (`torch.nn.Module`):
            The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
            Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Raises:
        ValueError:
            If fewer than two transformer blocks are found among the direct children of `module`. Both a head block
            and a tail block are required.

    Example:
        ```python
        >>> import torch
        >>> from diffusers import CogView4Pipeline
        >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

        >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

        >>> prompt = "A photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
        >>> image.save("output.png")
        ```
    """

    state_manager = StateManager(FBCSharedBlockState, (), {})
    remaining_blocks = []

    # Collect blocks from direct children that are ModuleLists stored under a known block-list name.
    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

    # Fail with an actionable message instead of an IndexError from `pop` on a too-short list.
    if len(remaining_blocks) < 2:
        raise ValueError(
            f"First Block Cache requires at least two transformer blocks (a head and a tail block), but found "
            f"{len(remaining_blocks)} in `{module.__class__.__name__}`. Blocks are looked up in direct "
            f"`torch.nn.ModuleList` children named one of {_ALL_TRANSFORMER_BLOCK_IDENTIFIERS}."
        )

    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
248
+
249
+
250
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    """Register an `FBCHeadBlockHook` on `block` under the head-block hook name."""
    HookRegistry.check_if_exists_or_initialize(block).register_hook(
        FBCHeadBlockHook(state_manager, threshold),
        _FBC_LEADER_BLOCK_HOOK,
    )
254
+
255
+
256
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    """Register an `FBCBlockHook` on `block`; `is_tail` marks the last block of the model."""
    HookRegistry.check_if_exists_or_initialize(block).register_hook(
        FBCBlockHook(state_manager, is_tail),
        _FBC_BLOCK_HOOK,
    )