diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
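
Several of the files above are entirely new pipelines (Chroma, Cosmos, HiDream, QwenImage, SkyReels-V2, Flux Kontext). As a minimal sketch of picking one of them up after upgrading, assuming the usual diffusers loading convention; the checkpoint id and dtype below are illustrative assumptions, not part of this diff:

import torch
from diffusers import FluxKontextPipeline  # new in 0.35.0 via pipeline_flux_kontext.py (entry 286)

# Hypothetical checkpoint; any FLUX.1 Kontext weights laid out for diffusers should work.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

The largest single-file change in the list, diffusers/models/transformers/transformer_flux.py (entry 125, +376 -138), is shown below.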
diffusers/models/transformers/transformer_flux.py

@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,36 +12,337 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from typing import Any, Dict, Optional, Tuple, Union
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    apply_rotary_emb,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    encoder_query = encoder_key = encoder_value = None
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+    return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+    encoder_query = encoder_key = encoder_value = (None,)
+    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+    return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    if attn.fused_projections:
+        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+    _attention_backend = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+    def __call__(
+        self,
+        attn: "FluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+        hidden_states = dispatch_attention_fn(
+            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FluxIPAdapterAttnProcessor(torch.nn.Module):
+    """Flux Attention processor for IP-Adapter."""
+
+    _attention_backend = None
+
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+
+    def __call__(
+        self,
+        attn: "FluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        ip_hidden_states: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+        ip_query = query
+
+        if encoder_hidden_states is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            # IP-adapter
+            ip_attn_output = torch.zeros_like(hidden_states)
+
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+                current_ip_hidden_states = dispatch_attention_fn(
+                    ip_query,
+                    ip_key,
+                    ip_value,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False,
+                    backend=self._attention_backend,
+                )
+                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states
+
+            return hidden_states, encoder_hidden_states, ip_attn_output
+        else:
+            return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = FluxAttnProcessor
+    _available_processors = [
+        FluxAttnProcessor,
+        FluxIPAdapterAttnProcessor,
+    ]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
+        pre_only: bool = False,
+        elementwise_affine: bool = True,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.head_dim = dim_head
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.added_proj_bias = added_proj_bias
+
+        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.pre_only:
+            self.to_out = torch.nn.ModuleList([])
+            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(torch.nn.Dropout(dropout))
+
+        if added_kv_proj_dim is not None:
+            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
     def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
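
The hunk above replaces the shared Attention / FluxAttnProcessor2_0 pair with a model-local FluxAttention whose default processor routes through dispatch_attention_fn. A minimal standalone sketch, assuming diffusers 0.35.0 and a PyTorch recent enough to provide torch.nn.RMSNorm; the dimensions mirror FLUX's 3072-wide, 24-head layout but are otherwise illustrative:

import torch
from diffusers.models.transformers.transformer_flux import FluxAttention

attn = FluxAttention(
    query_dim=3072, heads=24, dim_head=128, out_dim=3072,
    added_kv_proj_dim=3072, bias=True, context_pre_only=False, eps=1e-6,
)
image_tokens = torch.randn(1, 64, 3072)  # packed latent tokens (toy sequence length)
text_tokens = torch.randn(1, 16, 3072)   # encoder hidden states

# With encoder_hidden_states given, the default FluxAttnProcessor returns both streams.
image_out, text_out = attn(hidden_states=image_tokens, encoder_hidden_states=text_tokens)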
@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            from ..attention_processor import FluxAttnProcessor2_0_NPU
+
             deprecation_message = (
                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                 "should be set explicitly using the `set_attn_processor` method."
@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module):
             deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
-            processor = FluxAttnProcessor2_0()
+            processor = FluxAttnProcessor()
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
             processor=processor,
-            qk_norm="rms_norm",
             eps=1e-6,
             pre_only=True,
         )
@@ -79,10 +380,14 @@ class FluxSingleTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        text_seq_len = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +405,8 @@ class FluxSingleTransformerBlock(nn.Module):
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
 
-        return hidden_states
+        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+        return encoder_hidden_states, hidden_states
 
 
 @maybe_allow_in_graph
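
Taken together, these four hunks change FluxSingleTransformerBlock's calling convention: the caller no longer concatenates text and image tokens itself; the block accepts encoder_hidden_states, joins and re-splits the streams internally, and returns a pair. A hedged sketch of the new call with toy tensors:

import torch
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

block = FluxSingleTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
img = torch.randn(1, 64, 3072)
txt = torch.randn(1, 16, 3072)
temb = torch.randn(1, 3072)

# 0.33.1: hidden_states = block(torch.cat([txt, img], dim=1), temb, ...)
# 0.35.0: the block concatenates and splits on its own and returns both streams.
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=temb)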
@@ -113,17 +419,15 @@ class FluxTransformerBlock(nn.Module):
113
419
  self.norm1 = AdaLayerNormZero(dim)
114
420
  self.norm1_context = AdaLayerNormZero(dim)
115
421
 
116
- self.attn = Attention(
422
+ self.attn = FluxAttention(
117
423
  query_dim=dim,
118
- cross_attention_dim=None,
119
424
  added_kv_proj_dim=dim,
120
425
  dim_head=attention_head_dim,
121
426
  heads=num_attention_heads,
122
427
  out_dim=dim,
123
428
  context_pre_only=False,
124
429
  bias=True,
125
- processor=FluxAttnProcessor2_0(),
126
- qk_norm=qk_norm,
430
+ processor=FluxAttnProcessor(),
127
431
  eps=eps,
128
432
  )
129
433
 
@@ -147,6 +451,7 @@ class FluxTransformerBlock(nn.Module):
147
451
  encoder_hidden_states, emb=temb
148
452
  )
149
453
  joint_attention_kwargs = joint_attention_kwargs or {}
454
+
150
455
  # Attention.
151
456
  attention_outputs = self.attn(
152
457
  hidden_states=norm_hidden_states,
@@ -175,7 +480,6 @@ class FluxTransformerBlock(nn.Module):
175
480
  hidden_states = hidden_states + ip_attn_output
176
481
 
177
482
  # Process attention outputs for the `encoder_hidden_states`.
178
-
179
483
  context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
180
484
  encoder_hidden_states = encoder_hidden_states + context_attn_output
181
485
 
@@ -190,8 +494,45 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class FluxTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux.
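Note: `FluxPosEmbed` moves into this module; it builds one rotary table per ID axis via `get_1d_rotary_pos_embed` and concatenates them along the last dimension. A minimal usage sketch of the class added above; the import path and the 64x64 latent grid are illustrative:

```python
import torch
from diffusers.models.transformers.transformer_flux import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# One row of ids per token, one column per rotary axis.
h = w = 64
ids = torch.zeros(h * w, 3)
ids[:, 1] = torch.arange(h).repeat_interleave(w)  # row index
ids[:, 2] = torch.arange(w).repeat(h)             # column index

freqs_cos, freqs_sin = pos_embed(ids)
# Each output's last dim is sum(axes_dim) = 128, matching the head dim.
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([4096, 128]) twice
```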
@@ -227,6 +568,7 @@ class FluxTransformer2DModel(
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

     @register_to_config
     def __init__(
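Note: the new `_repeated_blocks` attribute marks the block classes that repeat along the depth of the model. A hedged sketch of the regional-compilation helper this appears to feed; `compile_repeated_blocks` and its `fullgraph` kwarg should be verified against the 0.35.0 docs:

```python
import torch
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# _repeated_blocks tells the helper which block classes to compile once and
# reuse for every repeated layer, instead of compiling the full graph.
model.compile_repeated_blocks(fullgraph=True)
```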
@@ -241,7 +583,7 @@ class FluxTransformer2DModel(
         joint_attention_dim: int = 4096,
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
-        axes_dims_rope: Tuple[int] = (16, 56, 56),
+        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
@@ -286,106 +628,6 @@ class FluxTransformer2DModel(

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
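Note: the roughly one hundred deleted lines are not lost functionality; `attn_processors`, `set_attn_processor`, and the QKV-fusion helpers are now expected to be inherited from the `AttentionMixin` base added to the class above, replacing per-model copies. A quick sanity check under that assumption:

```python
# Sanity check, assuming the removed helpers are inherited from AttentionMixin.
from diffusers import FluxTransformer2DModel

for attr in ("attn_processors", "set_attn_processor", "fuse_qkv_projections"):
    # hasattr on the class resolves properties and inherited methods alike.
    print(attr, "->", hasattr(FluxTransformer2DModel, attr))
```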
@@ -447,8 +689,6 @@ class FluxTransformer2DModel(
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
-        else:
-            guidance = None

         temb = (
             self.time_text_embed(timestep, pooled_projections)
@@ -486,6 +726,7 @@ class FluxTransformer2DModel(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )

             else:
@@ -508,20 +749,22 @@ class FluxTransformer2DModel(
                     )
                 else:
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )

             else:
-                hidden_states = block(
+                encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
@@ -531,12 +774,7 @@ class FluxTransformer2DModel(
             if controlnet_single_block_samples is not None:
                 interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                 interval_control = int(np.ceil(interval_control))
-                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                    + controlnet_single_block_samples[index_block // interval_control]
-                )
-
-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]

         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
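Note: with text and image tokens now kept in separate tensors through the single-stream blocks, the ControlNet residual becomes a plain add and the final de-prefixing slice disappears. A toy equivalence check of the old slice-and-write against the new add, with illustrative shapes:

```python
import torch

text_len, img_len, dim = 8, 16, 4
fused = torch.randn(1, text_len + img_len, dim)  # 0.33.1 layout: [text | image]
residual = torch.randn(1, img_len, dim)          # controlnet single-block sample

# Old: write the residual into the image slice of the fused tensor, then slice.
old = fused.clone()
old[:, text_len:, ...] = old[:, text_len:, ...] + residual
old = old[:, text_len:, ...]

# New: the image stream is already its own tensor, so it is a plain add.
new = fused[:, text_len:, ...] + residual

assert torch.allclose(old, new)
```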