diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_dispatch.py
@@ -0,0 +1,1218 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ import functools
+ import inspect
+ import math
+ from enum import Enum
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+ import torch
+
+ from ..utils import (
+     get_logger,
+     is_flash_attn_3_available,
+     is_flash_attn_available,
+     is_flash_attn_version,
+     is_sageattention_available,
+     is_sageattention_version,
+     is_torch_npu_available,
+     is_torch_version,
+     is_torch_xla_available,
+     is_torch_xla_version,
+     is_xformers_available,
+     is_xformers_version,
+ )
+ from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+
+
+ _REQUIRED_FLASH_VERSION = "2.6.3"
+ _REQUIRED_SAGE_VERSION = "2.1.1"
+ _REQUIRED_FLEX_VERSION = "2.5.0"
+ _REQUIRED_XLA_VERSION = "2.2"
+ _REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+ _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+ _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+ _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+ _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+ _CAN_USE_NPU_ATTN = is_torch_npu_available()
+ _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+ if _CAN_USE_FLASH_ATTN:
+     from flash_attn import flash_attn_func, flash_attn_varlen_func
+ else:
+     flash_attn_func = None
+     flash_attn_varlen_func = None
+
+
+ if _CAN_USE_FLASH_ATTN_3:
+     from flash_attn_interface import flash_attn_func as flash_attn_3_func
+     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+ else:
+     flash_attn_3_func = None
+     flash_attn_3_varlen_func = None
+
+
+ if _CAN_USE_SAGE_ATTN:
+     from sageattention import (
+         sageattn,
+         sageattn_qk_int8_pv_fp8_cuda,
+         sageattn_qk_int8_pv_fp8_cuda_sm90,
+         sageattn_qk_int8_pv_fp16_cuda,
+         sageattn_qk_int8_pv_fp16_triton,
+         sageattn_varlen,
+     )
+ else:
+     sageattn = None
+     sageattn_qk_int8_pv_fp16_cuda = None
+     sageattn_qk_int8_pv_fp16_triton = None
+     sageattn_qk_int8_pv_fp8_cuda = None
+     sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+     sageattn_varlen = None
+
+
+ if _CAN_USE_FLEX_ATTN:
+     # We cannot import the flex_attention function from the package directly because it is expected (from the
+     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+     # compiled function.
+     import torch.nn.attention.flex_attention as flex_attention
+
+
+ if _CAN_USE_NPU_ATTN:
+     from torch_npu import npu_fusion_attention
+ else:
+     npu_fusion_attention = None
+
+
+ if _CAN_USE_XLA_ATTN:
+     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+ else:
+     xla_flash_attention = None
+
+
+ if _CAN_USE_XFORMERS_ATTN:
+     import xformers.ops as xops
+ else:
+     xops = None
+
+
+ logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+ # TODO(aryan): Add support for the following:
+ # - Sage Attention++
+ # - block sparse, radial and other attention methods
+ # - CP with sage attention, flex, xformers, other missing backends
+ # - Add support for normal and CP training with backends that don't support it yet
+
+ _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
+ _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
+ _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+
+ class AttentionBackendName(str, Enum):
+     # EAGER = "eager"
+
+     # `flash-attn`
+     FLASH = "flash"
+     FLASH_VARLEN = "flash_varlen"
+     _FLASH_3 = "_flash_3"
+     _FLASH_VARLEN_3 = "_flash_varlen_3"
+
+     # PyTorch native
+     FLEX = "flex"
+     NATIVE = "native"
+     _NATIVE_CUDNN = "_native_cudnn"
+     _NATIVE_EFFICIENT = "_native_efficient"
+     _NATIVE_FLASH = "_native_flash"
+     _NATIVE_MATH = "_native_math"
+     _NATIVE_NPU = "_native_npu"
+     _NATIVE_XLA = "_native_xla"
+
+     # `sageattention`
+     SAGE = "sage"
+     SAGE_VARLEN = "sage_varlen"
+     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+     _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+     _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+     # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+     # We can look into supporting something "autotune"-ing in the future
+     # SPARGE = "sparge"
+
+     # `xformers`
+     XFORMERS = "xformers"
+
+
+ class _AttentionBackendRegistry:
+     _backends = {}
+     _constraints = {}
+     _supported_arg_names = {}
+     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
+     _checks_enabled = DIFFUSERS_ATTN_CHECKS
+
+     @classmethod
+     def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+         logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
+
+         def decorator(func):
+             cls._backends[backend] = func
+             cls._constraints[backend] = constraints or []
+             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+             return func
+
+         return decorator
+
+     @classmethod
+     def get_active_backend(cls):
+         return cls._active_backend, cls._backends[cls._active_backend]
+
+     @classmethod
+     def list_backends(cls):
+         return list(cls._backends.keys())
+
+
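The registry's process-wide default comes from DIFFUSERS_ATTN_BACKEND. A small, hypothetical illustration of switching that default (assuming, as the comment in dispatch_attention_fn below states, that the constant is read from an environment variable of the same name at import time):

import os

# Hypothetical: values mirror AttentionBackendName, e.g. "native", "flash", "flex", "xformers".
os.environ["DIFFUSERS_ATTN_BACKEND"] = "native"

import diffusers  # noqa: E402  # the default backend is picked up when diffusers is imported
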
+ @contextlib.contextmanager
+ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
+     """
+     Context manager to set the active attention backend.
+     """
+     if backend not in _AttentionBackendRegistry._backends:
+         raise ValueError(f"Backend {backend} is not registered.")
+
+     backend = AttentionBackendName(backend)
+     _check_attention_backend_requirements(backend)
+
+     old_backend = _AttentionBackendRegistry._active_backend
+     _AttentionBackendRegistry._active_backend = backend
+
+     try:
+         yield
+     finally:
+         _AttentionBackendRegistry._active_backend = old_backend
+
+
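For orientation, a minimal usage sketch (hypothetical, assuming the module is importable as diffusers.models.attention_dispatch in 0.35.0 and that flash-attn >= 2.6.3 is installed): everything dispatched inside the block is routed through the selected backend, and the previous backend is restored on exit.

import torch
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    attention_backend,
    dispatch_attention_fn,
)

# flash-attn expects (batch, seq_len, num_heads, head_dim) tensors in bf16/fp16 on CUDA.
q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

with attention_backend(AttentionBackendName.FLASH):
    out = dispatch_attention_fn(q, k, v)
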
+ def dispatch_attention_fn(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+     attention_kwargs: Optional[Dict[str, Any]] = None,
+     *,
+     backend: Optional[AttentionBackendName] = None,
+ ) -> torch.Tensor:
+     attention_kwargs = attention_kwargs or {}
+
+     if backend is None:
+         # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+         # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+         backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+     else:
+         backend_name = AttentionBackendName(backend)
+         backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
+     kwargs = {
+         "query": query,
+         "key": key,
+         "value": value,
+         "attn_mask": attn_mask,
+         "dropout_p": dropout_p,
+         "is_causal": is_causal,
+         "scale": scale,
+         **attention_kwargs,
+     }
+     if is_torch_version(">=", "2.5.0"):
+         kwargs["enable_gqa"] = enable_gqa
+
+     if _AttentionBackendRegistry._checks_enabled:
+         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
+         if removed_kwargs:
+             logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
+         for check in _AttentionBackendRegistry._constraints.get(backend_name):
+             check(**kwargs)
+
+     kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+     return backend_fn(**kwargs)
+
+
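dispatch_attention_fn can also be pointed at a backend per call, bypassing the active context. A hypothetical sketch (assuming the PyTorch-native backend registered further down in this file accepts the usual scaled_dot_product_attention layout of (batch, heads, seq_len, head_dim)):

import torch
from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn

q = k = v = torch.randn(2, 8, 64, 64)

# Explicit per-call backend selection; kwargs the backend function does not accept
# are filtered out before it is invoked, as shown above.
out = dispatch_attention_fn(q, k, v, is_causal=True, backend=AttentionBackendName.NATIVE)
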
+ # ===== Checks =====
+ # A list of very simple functions to catch common errors quickly when debugging.
+
+
+ def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+     if attn_mask is not None and is_causal:
+         raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+ def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+     if query.device != key.device or query.device != value.device:
+         raise ValueError("Query, key, and value must be on the same device.")
+     if query.dtype != key.dtype or query.dtype != value.dtype:
+         raise ValueError("Query, key, and value must have the same dtype.")
+
+
+ def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+     _check_device(query, key, value)
+     if query.device.type != "cuda":
+         raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+ def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+     def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+         _check_device_cuda(query, key, value)
+         if torch.cuda.get_device_capability(query.device) < (major, minor):
+             raise ValueError(
+                 f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+             )
+
+     return check_device_cuda
+
+
+ def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+     if query.dtype != key.dtype:
+         raise ValueError("Query and key must have the same dtype.")
+     if query.dtype != value.dtype:
+         raise ValueError("Query and value must have the same dtype.")
+
+
+ def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+     _check_qkv_dtype_match(query, key, value)
+     if query.dtype not in (torch.bfloat16, torch.float16):
+         raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+ def _check_shape(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ) -> None:
+     if query.shape[-1] != key.shape[-1]:
+         raise ValueError("Query and key must have the same last dimension.")
+     if query.shape[-2] != value.shape[-2]:
+         raise ValueError("Query and value must have the same second to last dimension.")
+     if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
+         raise ValueError("Attention mask must match the key's second to last dimension.")
+
+
+ # ===== Helper functions =====
+
+
+ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
+     if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+         if not _CAN_USE_FLASH_ATTN:
+             raise RuntimeError(
+                 f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+             )
+
+     elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+         if not _CAN_USE_FLASH_ATTN_3:
+             raise RuntimeError(
+                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+             )
+
+     elif backend in [
+         AttentionBackendName.SAGE,
+         AttentionBackendName.SAGE_VARLEN,
+         AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+         AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+         AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+         AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+     ]:
+         if not _CAN_USE_SAGE_ATTN:
+             raise RuntimeError(
+                 f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+             )
+
+     elif backend == AttentionBackendName.FLEX:
+         if not _CAN_USE_FLEX_ATTN:
+             raise RuntimeError(
+                 f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+             )
+
+     elif backend == AttentionBackendName._NATIVE_NPU:
+         if not _CAN_USE_NPU_ATTN:
+             raise RuntimeError(
+                 f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+             )
+
+     elif backend == AttentionBackendName._NATIVE_XLA:
+         if not _CAN_USE_XLA_ATTN:
+             raise RuntimeError(
+                 f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+             )
+
+     elif backend == AttentionBackendName.XFORMERS:
+         if not _CAN_USE_XFORMERS_ATTN:
+             raise RuntimeError(
+                 f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+             )
+
+
+ @functools.lru_cache(maxsize=128)
+ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
+     batch_size: int,
+     seq_len_q: int,
+     seq_len_kv: int,
+     device: Optional[torch.device] = None,
+ ):
+     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+     cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+     cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+     cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+     max_seqlen_q = seqlens_q.max().item()
+     max_seqlen_k = seqlens_k.max().item()
+     return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
+     batch_size: int,
+     seq_len_q: int,
+     attn_mask: torch.Tensor,
+     device: Optional[torch.device] = None,
+ ):
+     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+     seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+     cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+     cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+     cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+     max_seqlen_q = seqlens_q.max().item()
+     max_seqlen_k = seqlens_k.max().item()
+     return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+ def _prepare_for_flash_attn_or_sage_varlen(
+     batch_size: int,
+     seq_len_q: int,
+     seq_len_kv: int,
+     attn_mask: Optional[torch.Tensor] = None,
+     device: Optional[torch.device] = None,
+ ) -> None:
+     if attn_mask is None:
+         return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
+     return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
+
+
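As a concrete illustration of what these helpers compute (a hypothetical direct call; normally they are invoked by the varlen backends below): for a batch of 2 with every query sequence of length 3 and every key sequence of length 5, the cumulative sequence lengths are the running offsets that varlen kernels use to index the packed tensors.

import torch
from diffusers.models.attention_dispatch import _prepare_for_flash_attn_or_sage_varlen

(seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_q, max_k) = _prepare_for_flash_attn_or_sage_varlen(
    batch_size=2, seq_len_q=3, seq_len_kv=5, device=torch.device("cpu")
)
# seqlens_q    -> [3, 3]       seqlens_k    -> [5, 5]
# cu_seqlens_q -> [0, 3, 6]    cu_seqlens_k -> [0, 5, 10]
# max_q, max_k -> 3, 5
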
+ def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+     """
+     Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
+     FlashAttention/Sage varlen.
+
+     Supports 1D to 4D shapes and common broadcasting patterns.
+     """
+     if attn_mask.dtype != torch.bool:
+         raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+     if attn_mask.ndim == 1:
+         # [seq_len_k] -> broadcast across batch
+         attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+     elif attn_mask.ndim == 2:
+         # [batch_size, seq_len_k]. Maybe broadcast across batch
+         if attn_mask.size(0) not in [1, batch_size]:
+             raise ValueError(
+                 f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+             )
+         attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+     elif attn_mask.ndim == 3:
+         # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+         # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
+         if attn_mask.size(0) not in [1, batch_size]:
+             raise ValueError(
+                 f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+             )
+         attn_mask = attn_mask.any(dim=1)
+         attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+     elif attn_mask.ndim == 4:
+         # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+         if attn_mask.size(0) not in [1, batch_size]:
+             raise ValueError(
+                 f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+             )
+         attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)  # [B, H, Q, K]
+         attn_mask = attn_mask.any(dim=(1, 2))  # [B, K]
+
+     else:
+         raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+     if attn_mask.shape != (batch_size, seq_len_k):
+         raise ValueError(
+             f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+         )
+
+     return attn_mask
+
+
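For example (a hypothetical padding-mask case), a 3D [batch_size, seq_len_q, seq_len_k] boolean mask is reduced over the query dimension into a per-key validity mask:

import torch
from diffusers.models.attention_dispatch import _normalize_attn_mask

mask = torch.ones(2, 4, 6, dtype=torch.bool)  # [batch, seq_len_q, seq_len_k]
mask[1, :, 4:] = False                        # second sample has only 4 valid key positions

normalized = _normalize_attn_mask(mask, batch_size=2, seq_len_k=6)
# normalized.shape == (2, 6)
# normalized[1]    == [True, True, True, True, False, False]
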
+ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+     return q_idx >= kv_idx
+
+
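A mask_mod of this form is what torch.nn.attention.flex_attention consumes when building block masks (torch >= 2.5). A rough sketch, not taken from this file:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from diffusers.models.attention_dispatch import _flex_attention_causal_mask_mod

# Block masks are built once per (Q_LEN, KV_LEN) and reused; this sketch assumes a CUDA device.
block_mask = create_block_mask(_flex_attention_causal_mask_mod, B=None, H=None, Q_LEN=256, KV_LEN=256)

q = k = v = torch.randn(1, 8, 256, 64, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
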
474
+ # ===== torch op registrations =====
475
+ # Registrations are required for fullgraph tracing compatibility
476
+
477
+
478
+ # TODO: library.custom_op and register_fake probably need version guards?
479
+ # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
480
+ # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
481
+ @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
482
+ def _wrapped_flash_attn_3_original(
483
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
484
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
485
+ out, lse = flash_attn_3_func(query, key, value)
486
+ lse = lse.permute(0, 2, 1)
487
+ return out, lse
488
+
489
+
490
+ @torch.library.register_fake("flash_attn_3::_flash_attn_forward")
491
+ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
492
+ batch_size, seq_len, num_heads, head_dim = query.shape
493
+ lse_shape = (batch_size, seq_len, num_heads)
494
+ return torch.empty_like(query), query.new_empty(lse_shape)
495
+
496
+
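The pair above registers the FA3 forward as a torch custom op together with a fake (meta) kernel that only reports output shapes, which is what lets fullgraph tracing step over the opaque CUDA kernel. A generic sketch of the same pattern, with hypothetical op and library names, assuming a PyTorch recent enough to expose torch.library.custom_op and register_fake:

# Illustrative sketch of the custom-op + fake-kernel pattern (op name "mylib::opaque_matmul" is made up).
import torch

@torch.library.custom_op("mylib::opaque_matmul", mutates_args=(), device_types="cuda")
def opaque_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Real implementation; runs only when the op is actually executed.
    return a @ b

@torch.library.register_fake("mylib::opaque_matmul")
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake kernel: describe only the output shape/dtype so torch.compile can trace without running the kernel.
    return a.new_empty((a.shape[0], b.shape[1]))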
+ # ===== Attention backends =====
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.FLASH,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _flash_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     dropout_p: float = 0.0,
+     scale: Optional[float] = None,
+     is_causal: bool = False,
+     window_size: Tuple[int, int] = (-1, -1),
+     softcap: float = 0.0,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     deterministic: bool = False,
+     return_attn_probs: bool = False,
+ ) -> torch.Tensor:
+     out = flash_attn_func(
+         q=query,
+         k=key,
+         v=value,
+         dropout_p=dropout_p,
+         softmax_scale=scale,
+         causal=is_causal,
+         window_size=window_size,
+         softcap=softcap,
+         alibi_slopes=alibi_slopes,
+         deterministic=deterministic,
+         return_attn_probs=return_attn_probs,
+     )
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.FLASH_VARLEN,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _flash_varlen_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seqlen_q: Optional[int] = None,
+     max_seqlen_k: Optional[int] = None,
+     dropout_p: float = 0.0,
+     scale: Optional[float] = None,
+     is_causal: bool = False,
+     window_size: Tuple[int, int] = (-1, -1),
+     softcap: float = 0.0,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     deterministic: bool = False,
+     return_attn_probs: bool = False,
+     attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     batch_size, seq_len_q, _, _ = query.shape
+     _, seq_len_kv, _, _ = key.shape
+
+     if attn_mask is not None:
+         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+             _prepare_for_flash_attn_or_sage_varlen(
+                 batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+             )
+         )
+     else:
+         seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+         cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+         cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+     key_valid, value_valid = [], []
+     for b in range(batch_size):
+         valid_len = seqlens_k[b]
+         key_valid.append(key[b, :valid_len])
+         value_valid.append(value[b, :valid_len])
+
+     query_packed = query.flatten(0, 1)
+     key_packed = torch.cat(key_valid, dim=0)
+     value_packed = torch.cat(value_valid, dim=0)
+
+     out = flash_attn_varlen_func(
+         q=query_packed,
+         k=key_packed,
+         v=value_packed,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k=max_seqlen_k,
+         dropout_p=dropout_p,
+         softmax_scale=scale,
+         causal=is_causal,
+         window_size=window_size,
+         softcap=softcap,
+         alibi_slopes=alibi_slopes,
+         deterministic=deterministic,
+         return_attn_probs=return_attn_probs,
+     )
+     out = out.unflatten(0, (batch_size, -1))
+
+     return out
+
+
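The packing logic above is where the varlen path differs from the dense one: queries are flattened across the batch as-is, while keys and values keep only their valid prefix per sequence before being concatenated, so cu_seqlens_k indexes into a single ragged buffer. A toy sketch of that packing, with made-up sizes and no real attention call:

# Illustrative sketch of packing ragged key/value sequences for a varlen kernel (no kernel is called).
import torch

batch_size, seq_len_kv, num_heads, head_dim = 2, 6, 1, 8
key = torch.randn(batch_size, seq_len_kv, num_heads, head_dim)
seqlens_k = torch.tensor([6, 4])  # the second sequence has 2 padding positions

key_valid = [key[b, : seqlens_k[b]] for b in range(batch_size)]
key_packed = torch.cat(key_valid, dim=0)   # shape [10, 1, 8], i.e. [sum(seqlens_k), H, D]
cu_seqlens_k = torch.tensor([0, 6, 10])    # row b of the batch lives in key_packed[cu_seqlens_k[b]:cu_seqlens_k[b + 1]]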
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._FLASH_3,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _flash_attention_3(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     scale: Optional[float] = None,
+     is_causal: bool = False,
+     window_size: Tuple[int, int] = (-1, -1),
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     return_attn_probs: bool = False,
+ ) -> torch.Tensor:
+     out, lse, *_ = flash_attn_3_func(
+         q=query,
+         k=key,
+         v=value,
+         softmax_scale=scale,
+         causal=is_causal,
+         qv=None,
+         q_descale=None,
+         k_descale=None,
+         v_descale=None,
+         window_size=window_size,
+         attention_chunk=0,
+         softcap=softcap,
+         num_splits=1,
+         pack_gqa=None,
+         deterministic=deterministic,
+         sm_margin=0,
+     )
+     return (out, lse) if return_attn_probs else out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._FLASH_VARLEN_3,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _flash_varlen_attention_3(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seqlen_q: Optional[int] = None,
+     max_seqlen_k: Optional[int] = None,
+     scale: Optional[float] = None,
+     is_causal: bool = False,
+     window_size: Tuple[int, int] = (-1, -1),
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     return_attn_probs: bool = False,
+     attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     batch_size, seq_len_q, _, _ = query.shape
+     _, seq_len_kv, _, _ = key.shape
+
+     if attn_mask is not None:
+         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+             _prepare_for_flash_attn_or_sage_varlen(
+                 batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+             )
+         )
+     else:
+         seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+         cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+         cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+     key_valid, value_valid = [], []
+     for b in range(batch_size):
+         valid_len = seqlens_k[b]
+         key_valid.append(key[b, :valid_len])
+         value_valid.append(value[b, :valid_len])
+
+     query_packed = query.flatten(0, 1)
+     key_packed = torch.cat(key_valid, dim=0)
+     value_packed = torch.cat(value_valid, dim=0)
+
+     out, lse, *_ = flash_attn_3_varlen_func(
+         q=query_packed,
+         k=key_packed,
+         v=value_packed,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k=max_seqlen_k,
+         seqused_q=None,
+         seqused_k=None,
+         softmax_scale=scale,
+         causal=is_causal,
+         qv=None,
+         q_descale=None,
+         k_descale=None,
+         v_descale=None,
+         window_size=window_size,
+         softcap=softcap,
+         num_splits=1,
+         pack_gqa=None,
+         deterministic=deterministic,
+         sm_margin=0,
+     )
+     out = out.unflatten(0, (batch_size, -1))
+
+     return (out, lse) if return_attn_probs else out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.FLEX,
+     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+ )
+ def _native_flex_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+     return_lse: bool = False,
+     kernel_options: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+     # TODO: should we LRU cache the block mask creation?
+     score_mod = None
+     block_mask = None
+     batch_size, seq_len_q, num_heads, _ = query.shape
+     _, seq_len_kv, _, _ = key.shape
+
+     if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+         block_mask = attn_mask
+     elif is_causal:
+         block_mask = flex_attention.create_block_mask(
+             _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+         )
+     elif torch.is_tensor(attn_mask):
+         if attn_mask.ndim == 2:
+             attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+         attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+         if attn_mask.dtype == torch.bool:
+             # TODO: this probably does not work but verify!
+             def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                 return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+             block_mask = flex_attention.create_block_mask(
+                 mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+             )
+         else:
+
+             def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+                 return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+     else:
+         raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     out = flex_attention.flex_attention(
+         query=query,
+         key=key,
+         value=value,
+         score_mod=score_mod,
+         block_mask=block_mask,
+         scale=scale,
+         enable_gqa=enable_gqa,
+         return_lse=return_lse,
+         kernel_options=kernel_options,
+     )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
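For the FLEX backend above, the mask is expressed either as a BlockMask built from a boolean mask_mod, or as an additive score_mod; torch's flex_attention expects [batch, heads, seq, dim] inputs, hence the permutes. A minimal causal sketch against the public PyTorch API, assuming a recent PyTorch (2.5+) and a CUDA device:

# Illustrative sketch of flex_attention with a causal BlockMask (sizes are made up).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, S, D = 1, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
block_mask = create_block_mask(causal_mask_mod, B, H, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)  # same [B, H, S, D] layout as the inputs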
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.NATIVE,
+     constraints=[_check_device, _check_shape],
+ )
+ def _native_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+ ) -> torch.Tensor:
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     out = torch.nn.functional.scaled_dot_product_attention(
+         query=query,
+         key=key,
+         value=value,
+         attn_mask=attn_mask,
+         dropout_p=dropout_p,
+         is_causal=is_causal,
+         scale=scale,
+         enable_gqa=enable_gqa,
+     )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
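The _NATIVE_* variants below differ only in which scaled_dot_product_attention kernel they pin through the torch.nn.attention.sdpa_kernel context manager; outside such a context, PyTorch chooses a backend on its own. A small standalone sketch of that mechanism, assuming PyTorch 2.3+ and a CUDA device:

# Illustrative sketch: pinning a specific SDPA backend for one call (shapes are made up).
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.bfloat16)  # [B, H, S, D]
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)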
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._NATIVE_CUDNN,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _native_cudnn_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+ ) -> torch.Tensor:
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+         out = torch.nn.functional.scaled_dot_product_attention(
+             query=query,
+             key=key,
+             value=value,
+             attn_mask=attn_mask,
+             dropout_p=dropout_p,
+             is_causal=is_causal,
+             scale=scale,
+             enable_gqa=enable_gqa,
+         )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._NATIVE_EFFICIENT,
+     constraints=[_check_device, _check_shape],
+ )
+ def _native_efficient_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+ ) -> torch.Tensor:
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+         out = torch.nn.functional.scaled_dot_product_attention(
+             query=query,
+             key=key,
+             value=value,
+             attn_mask=attn_mask,
+             dropout_p=dropout_p,
+             is_causal=is_causal,
+             scale=scale,
+             enable_gqa=enable_gqa,
+         )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._NATIVE_FLASH,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _native_flash_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+ ) -> torch.Tensor:
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+         out = torch.nn.functional.scaled_dot_product_attention(
+             query=query,
+             key=key,
+             value=value,
+             attn_mask=None,  # not supported
+             dropout_p=dropout_p,
+             is_causal=is_causal,
+             scale=scale,
+             enable_gqa=enable_gqa,
+         )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._NATIVE_MATH,
+     constraints=[_check_device, _check_shape],
+ )
+ def _native_math_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+ ) -> torch.Tensor:
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+         out = torch.nn.functional.scaled_dot_product_attention(
+             query=query,
+             key=key,
+             value=value,
+             attn_mask=attn_mask,
+             dropout_p=dropout_p,
+             is_causal=is_causal,
+             scale=scale,
+             enable_gqa=enable_gqa,
+         )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._NATIVE_NPU,
+     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _native_npu_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     dropout_p: float = 0.0,
+     scale: Optional[float] = None,
+ ) -> torch.Tensor:
+     return npu_fusion_attention(
+         query,
+         key,
+         value,
+         query.size(2),  # num_heads
+         input_layout="BSND",
+         pse=None,
+         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+         pre_tockens=65536,
+         next_tockens=65536,
+         keep_prob=1.0 - dropout_p,
+         sync=False,
+         inner_precise=0,
+     )[0]
+
+
+ # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._NATIVE_XLA,
+     constraints=[_check_device, _check_shape],
+ )
+ def _native_xla_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     is_causal: bool = False,
+ ) -> torch.Tensor:
+     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+     query = query / math.sqrt(query.shape[-1])
+     out = xla_flash_attention(
+         q=query,
+         k=key,
+         v=value,
+         causal=is_causal,
+     )
+     out = out.permute(0, 2, 1, 3)
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.SAGE,
+     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _sage_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     return sageattn(
+         q=query,
+         k=key,
+         v=value,
+         tensor_layout="NHD",
+         is_causal=is_causal,
+         sm_scale=scale,
+         return_lse=return_lse,
+     )
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.SAGE_VARLEN,
+     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ )
+ def _sage_varlen_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     cu_seqlens_q: Optional[torch.Tensor] = None,
+     cu_seqlens_k: Optional[torch.Tensor] = None,
+     max_seqlen_q: Optional[int] = None,
+     max_seqlen_k: Optional[int] = None,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     smooth_k: bool = True,
+     attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     batch_size, seq_len_q, _, _ = query.shape
+     _, seq_len_kv, _, _ = key.shape
+
+     if attn_mask is not None:
+         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+             _prepare_for_flash_attn_or_sage_varlen(
+                 batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+             )
+         )
+     else:
+         seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+         cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+         cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+     key_valid, value_valid = [], []
+     for b in range(batch_size):
+         valid_len = seqlens_k[b]
+         key_valid.append(key[b, :valid_len])
+         value_valid.append(value[b, :valid_len])
+
+     query_packed = query.flatten(0, 1)
+     key_packed = torch.cat(key_valid, dim=0)
+     value_packed = torch.cat(value_valid, dim=0)
+
+     out = sageattn_varlen(
+         q=query_packed,
+         k=key_packed,
+         v=value_packed,
+         cu_seqlens_q=cu_seqlens_q,
+         cu_seqlens_k=cu_seqlens_k,
+         max_seqlen_q=max_seqlen_q,
+         max_seqlen_k=max_seqlen_k,
+         is_causal=is_causal,
+         sm_scale=scale,
+         smooth_k=smooth_k,
+     )
+     out = out.unflatten(0, (batch_size, -1))
+
+     return out
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+     constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+ )
+ def _sage_qk_int8_pv_fp8_cuda_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+     pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+     smooth_k: bool = True,
+     smooth_v: bool = False,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     return sageattn_qk_int8_pv_fp8_cuda(
+         q=query,
+         k=key,
+         v=value,
+         tensor_layout="NHD",
+         is_causal=is_causal,
+         qk_quant_gran=qk_quant_gran,
+         sm_scale=scale,
+         pv_accum_dtype=pv_accum_dtype,
+         smooth_k=smooth_k,
+         smooth_v=smooth_v,
+         return_lse=return_lse,
+     )
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+     constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+ )
+ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+     pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+     smooth_k: bool = True,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     return sageattn_qk_int8_pv_fp8_cuda_sm90(
+         q=query,
+         k=key,
+         v=value,
+         tensor_layout="NHD",
+         is_causal=is_causal,
+         qk_quant_gran=qk_quant_gran,
+         sm_scale=scale,
+         pv_accum_dtype=pv_accum_dtype,
+         smooth_k=smooth_k,
+         return_lse=return_lse,
+     )
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+     constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+ )
+ def _sage_qk_int8_pv_fp16_cuda_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+     pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
+     smooth_k: bool = True,
+     smooth_v: bool = False,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     return sageattn_qk_int8_pv_fp16_cuda(
+         q=query,
+         k=key,
+         v=value,
+         tensor_layout="NHD",
+         is_causal=is_causal,
+         qk_quant_gran=qk_quant_gran,
+         sm_scale=scale,
+         pv_accum_dtype=pv_accum_dtype,
+         smooth_k=smooth_k,
+         smooth_v=smooth_v,
+         return_lse=return_lse,
+     )
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+     constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+ )
+ def _sage_qk_int8_pv_fp16_triton_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
+     smooth_k: bool = True,
+     return_lse: bool = False,
+ ) -> torch.Tensor:
+     return sageattn_qk_int8_pv_fp16_triton(
+         q=query,
+         k=key,
+         v=value,
+         tensor_layout="NHD",
+         quantization_backend=quantization_backend,
+         is_causal=is_causal,
+         sm_scale=scale,
+         smooth_k=smooth_k,
+         return_lse=return_lse,
+     )
+
+
+ @_AttentionBackendRegistry.register(
+     AttentionBackendName.XFORMERS,
+     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+ )
+ def _xformers_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attn_mask: Optional[torch.Tensor] = None,
+     dropout_p: float = 0.0,
+     is_causal: bool = False,
+     scale: Optional[float] = None,
+     enable_gqa: bool = False,
+ ) -> torch.Tensor:
+     batch_size, seq_len_q, num_heads_q, _ = query.shape
+     _, seq_len_kv, num_heads_kv, _ = key.shape
+
+     if is_causal:
+         attn_mask = xops.LowerTriangularMask()
+     elif attn_mask is not None:
+         if attn_mask.ndim == 2:
+             attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+         elif attn_mask.ndim != 4:
+             raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+         attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+     if enable_gqa:
+         if num_heads_q % num_heads_kv != 0:
+             raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+         num_heads_per_group = num_heads_q // num_heads_kv
+         query = query.unflatten(2, (num_heads_kv, -1))
+         key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+         value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+     out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+
+     if enable_gqa:
+         out = out.flatten(2, 3)
+
+     return out