diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
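
Among the changes above, the most visible structural addition is the new diffusers/guiders package, which introduces standalone guider classes for guidance strategies that were previously embedded inline in individual pipelines (classifier-free guidance, perturbed-attention guidance, skip-layer guidance, and the others listed above). As a rough illustration of the core operation such a guider encapsulates, here is a minimal sketch of the classifier-free guidance combination step; the function below is illustrative only and is not the package's actual API:

```python
import torch

def classifier_free_guidance(
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    guidance_scale: float = 7.5,
) -> torch.Tensor:
    # Standard CFG rule: extrapolate from the unconditional prediction
    # toward the conditional one, scaled by `guidance_scale`.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```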
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -203,8 +203,8 @@ class Attention(nn.Module):
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
         elif qk_norm == "rms_norm":
-            self.norm_q = RMSNorm(dim_head, eps=eps)
-            self.norm_k = RMSNorm(dim_head, eps=eps)
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "rms_norm_across_heads":
             # LTX applies qk norm across all heads
             self.norm_q = RMSNorm(dim_head * heads, eps=eps)
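
The hunk above passes the constructor's existing `elementwise_affine` flag through to the per-head query/key `RMSNorm` layers, which were previously always created with a learnable scale. A minimal sketch of what that flag controls, assuming an RMSNorm along the lines of the one in `diffusers/models/normalization.py`:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm sketch: x / rms(x), with an optional learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # With elementwise_affine=False the module has no parameters at all;
        # this is what the change above makes selectable for norm_q / norm_k.
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * self.weight if self.weight is not None else x
```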
@@ -2272,554 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states
 
 
-class FluxAttnProcessor2_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
2568
- class FusedFluxAttnProcessor2_0_NPU:
2569
- """Attention processor used typically in processing the SD3-like self-attention projections."""
2570
-
2571
- def __init__(self):
2572
- if not hasattr(F, "scaled_dot_product_attention"):
2573
- raise ImportError(
2574
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
2575
- )
2576
-
2577
- def __call__(
2578
- self,
2579
- attn: Attention,
2580
- hidden_states: torch.FloatTensor,
2581
- encoder_hidden_states: torch.FloatTensor = None,
2582
- attention_mask: Optional[torch.FloatTensor] = None,
2583
- image_rotary_emb: Optional[torch.Tensor] = None,
2584
- ) -> torch.FloatTensor:
2585
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2586
-
2587
- # `sample` projections.
2588
- qkv = attn.to_qkv(hidden_states)
2589
- split_size = qkv.shape[-1] // 3
2590
- query, key, value = torch.split(qkv, split_size, dim=-1)
2591
-
2592
- inner_dim = key.shape[-1]
2593
- head_dim = inner_dim // attn.heads
2594
-
2595
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2596
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2597
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2598
-
2599
- if attn.norm_q is not None:
2600
- query = attn.norm_q(query)
2601
- if attn.norm_k is not None:
2602
- key = attn.norm_k(key)
2603
-
2604
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2605
- # `context` projections.
2606
- if encoder_hidden_states is not None:
2607
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
2608
- split_size = encoder_qkv.shape[-1] // 3
2609
- (
2610
- encoder_hidden_states_query_proj,
2611
- encoder_hidden_states_key_proj,
2612
- encoder_hidden_states_value_proj,
2613
- ) = torch.split(encoder_qkv, split_size, dim=-1)
2614
-
2615
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2616
- batch_size, -1, attn.heads, head_dim
2617
- ).transpose(1, 2)
2618
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2619
- batch_size, -1, attn.heads, head_dim
2620
- ).transpose(1, 2)
2621
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2622
- batch_size, -1, attn.heads, head_dim
2623
- ).transpose(1, 2)
2624
-
2625
- if attn.norm_added_q is not None:
2626
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2627
- if attn.norm_added_k is not None:
2628
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2629
-
2630
- # attention
2631
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2632
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2633
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2634
-
2635
- if image_rotary_emb is not None:
2636
- from .embeddings import apply_rotary_emb
2637
-
2638
- query = apply_rotary_emb(query, image_rotary_emb)
2639
- key = apply_rotary_emb(key, image_rotary_emb)
2640
-
2641
- if query.dtype in (torch.float16, torch.bfloat16):
2642
- hidden_states = torch_npu.npu_fusion_attention(
2643
- query,
2644
- key,
2645
- value,
2646
- attn.heads,
2647
- input_layout="BNSD",
2648
- pse=None,
2649
- scale=1.0 / math.sqrt(query.shape[-1]),
2650
- pre_tockens=65536,
2651
- next_tockens=65536,
2652
- keep_prob=1.0,
2653
- sync=False,
2654
- inner_precise=0,
2655
- )[0]
2656
- else:
2657
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2658
-
2659
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2660
- hidden_states = hidden_states.to(query.dtype)
2661
-
2662
- if encoder_hidden_states is not None:
2663
- encoder_hidden_states, hidden_states = (
2664
- hidden_states[:, : encoder_hidden_states.shape[1]],
2665
- hidden_states[:, encoder_hidden_states.shape[1] :],
2666
- )
2667
-
2668
- # linear proj
2669
- hidden_states = attn.to_out[0](hidden_states)
2670
- # dropout
2671
- hidden_states = attn.to_out[1](hidden_states)
2672
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2673
-
2674
- return hidden_states, encoder_hidden_states
2675
- else:
2676
- return hidden_states
2677
-
2678
-
2679
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
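
The `# IP-adapter` branch above is the core of this removed class: every adapter projects the image-prompt tokens through its own key/value layers, attends against the shared query, and the per-adapter results are accumulated as a scale-weighted sum. A self-contained sketch of that accumulation (names and shapes are illustrative, not the library's API):

    from typing import List

    import torch
    import torch.nn.functional as F


    def combine_ip_adapters(
        ip_query: torch.Tensor,        # (batch, heads, seq, head_dim), shared across adapters
        ip_keys: List[torch.Tensor],   # one (batch, heads, n_ip_tokens, head_dim) per adapter
        ip_values: List[torch.Tensor],
        scales: List[float],           # per-adapter strength, like `self.scale` above
    ) -> torch.Tensor:
        batch, heads, seq, head_dim = ip_query.shape
        out = torch.zeros(batch, seq, heads * head_dim, dtype=ip_query.dtype, device=ip_query.device)
        for ip_key, ip_value, scale in zip(ip_keys, ip_values, scales):
            # Cross-attention from the image stream to this adapter's tokens.
            attn = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
            out += scale * attn.transpose(1, 2).reshape(batch, seq, heads * head_dim)
        return out
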
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3449,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states
 
 
-class XLAFluxFlashAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
-    """
-
-    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-        if is_torch_xla_version("<", "2.3"):
-            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
-        if is_spmd() and is_torch_xla_version("<", "2.4"):
-            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
-        self.partition_spec = partition_spec
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        query /= math.sqrt(head_dim)
-        hidden_states = flash_attention(query, key, value, causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-
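
One subtlety in the removed XLA path: the `query /= math.sqrt(head_dim)` line exists because, judging from this code, the pallas `flash_attention` kernel applies no softmax scaling of its own, whereas `F.scaled_dot_product_attention` scales by 1/sqrt(head_dim) internally. A small numerical check of that equivalence in plain PyTorch (no XLA required):

    import math

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq, head_dim)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)

    # SDPA divides the attention logits by sqrt(head_dim) internally ...
    ref = F.scaled_dot_product_attention(q, k, v)

    # ... so a kernel without built-in scaling needs the query pre-scaled,
    # which is exactly what the removed processor did before flash_attention.
    q_scaled = q / math.sqrt(q.shape[-1])
    manual = torch.softmax(q_scaled @ k.transpose(-2, -1), dim=-1) @ v

    assert torch.allclose(ref, manual, atol=1e-5)
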
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -3972,7 +3324,7 @@ class PAGHunyuanAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
     used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
-    variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+    variant of the processor employs [Perturbed Attention Guidance](https://huggingface.co/papers/2403.17377).
     """
 
     def __init__(self):
@@ -4095,7 +3447,7 @@ class PAGCFGHunyuanAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
     used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
-    variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+    variant of the processor employs [Perturbed Attention Guidance](https://huggingface.co/papers/2403.17377).
     """
 
     def __init__(self):
@@ -4828,7 +4180,7 @@ class SlicedAttnAddedKVProcessor:
 
 class SpatialNorm(nn.Module):
     """
-    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+    Spatially conditioned normalization as defined in https://huggingface.co/papers/2209.09002.
 
     Args:
         f_channels (`int`):
@@ -5693,7 +5045,7 @@ class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
 class PAGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    PAG reference: https://arxiv.org/abs/2403.17377
+    PAG reference: https://huggingface.co/papers/2403.17377
     """
 
     def __init__(self):
@@ -5792,7 +5144,7 @@ class PAGIdentitySelfAttnProcessor2_0:
 class PAGCFGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    PAG reference: https://arxiv.org/abs/2403.17377
+    PAG reference: https://huggingface.co/papers/2403.17377
     """
 
     def __init__(self):
@@ -5988,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
     pass
 
 
-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
-        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
-        super().__init__()
-
-
 class SanaLinearAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product linear attention.
@@ -6163,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
         return hidden_states
 
 
+class FluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_native_npu"
+        return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    def __new__(cls):
+        deprecation_message = (
+            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_fused_npu"
+        return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+            "alternative solution to using XLA Flash Attention will be provided in the future."
+        )
+        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+            deprecation_message = (
+                "partition_spec was not used in the processor implementation when it was added. Passing it "
+                "is a no-op and support for it will be removed."
+            )
+            deprecate("partition_spec", "1.0.0", deprecation_message)
+
+        processor = FluxAttnProcessor(*args, **kwargs)
+        processor._attention_backend = "_native_xla"
+        return processor
+
+
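
Taken together, the block above swaps full implementations for thin `__new__` shims: the legacy names stay importable from `attention_processor`, but constructing one now warns and returns an instance of the new `FluxAttnProcessor` family from `transformer_flux`, optionally pinned to a specific `_attention_backend`. A hedged sketch of the migration as seen from user code (assumes diffusers 0.35.0):

    # Legacy spelling still constructs, but emits a deprecation warning
    # and hands back the new class:
    from diffusers.models.attention_processor import FluxAttnProcessor2_0

    processor = FluxAttnProcessor2_0()
    print(type(processor).__name__)  # "FluxAttnProcessor"

    # Preferred spelling going forward:
    from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

    processor = FluxAttnProcessor()
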
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,