diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
The hunks below are from `diffusers/models/attention_processor.py` (listed above with +113 −667):

```diff
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -203,8 +203,8 @@ class Attention(nn.Module):
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
         elif qk_norm == "rms_norm":
-            self.norm_q = RMSNorm(dim_head, eps=eps)
-            self.norm_k = RMSNorm(dim_head, eps=eps)
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "rms_norm_across_heads":
             # LTX applies qk norm across all heads
             self.norm_q = RMSNorm(dim_head * heads, eps=eps)
```
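The functional change in this hunk is that the `Attention` constructor's existing `elementwise_affine` argument is now forwarded to the `rms_norm` branch, where the query/key norms were previously always built with a learnable scale. A minimal sketch of what that flag controls, assuming diffusers' own `RMSNorm` from `diffusers.models.normalization` (present in both versions; verify the behavior against your installed release):

```python
# Minimal sketch (based on diffusers' documented RMSNorm behavior, not taken
# from this diff): `elementwise_affine` decides whether the norm carries a
# learnable per-channel scale.
import torch
from diffusers.models.normalization import RMSNorm

dim_head = 64
x = torch.randn(2, 24, 128, dim_head)  # (batch, heads, seq_len, head_dim)

norm_with_scale = RMSNorm(dim_head, eps=1e-6, elementwise_affine=True)  # pre-0.35 hard-coded path
norm_plain = RMSNorm(dim_head, eps=1e-6, elementwise_affine=False)      # newly reachable path

print(isinstance(norm_with_scale.weight, torch.nn.Parameter))  # True: learned rescaling
print(norm_plain.weight)                                       # None: pure RMS normalization
print(norm_plain(x).shape)                                     # torch.Size([2, 24, 128, 64])
```

The third hunk in this file, shown next, is the large removal that accounts for most of `attention_processor.py`'s −667 lines: the Flux attention processor family is deleted wholesale.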
@@ -2272,554 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
|
|
2272
2272
|
return hidden_states
|
2273
2273
|
|
2274
2274
|
|
2275
|
-
class FluxAttnProcessor2_0:
|
2276
|
-
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
2277
|
-
|
2278
|
-
def __init__(self):
|
2279
|
-
if not hasattr(F, "scaled_dot_product_attention"):
|
2280
|
-
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
2281
|
-
|
2282
|
-
def __call__(
|
2283
|
-
self,
|
2284
|
-
attn: Attention,
|
2285
|
-
hidden_states: torch.FloatTensor,
|
2286
|
-
encoder_hidden_states: torch.FloatTensor = None,
|
2287
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
2288
|
-
image_rotary_emb: Optional[torch.Tensor] = None,
|
2289
|
-
) -> torch.FloatTensor:
|
2290
|
-
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2291
|
-
|
2292
|
-
# `sample` projections.
|
2293
|
-
query = attn.to_q(hidden_states)
|
2294
|
-
key = attn.to_k(hidden_states)
|
2295
|
-
value = attn.to_v(hidden_states)
|
2296
|
-
|
2297
|
-
inner_dim = key.shape[-1]
|
2298
|
-
head_dim = inner_dim // attn.heads
|
2299
|
-
|
2300
|
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2301
|
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2302
|
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2303
|
-
|
2304
|
-
if attn.norm_q is not None:
|
2305
|
-
query = attn.norm_q(query)
|
2306
|
-
if attn.norm_k is not None:
|
2307
|
-
key = attn.norm_k(key)
|
2308
|
-
|
2309
|
-
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
2310
|
-
if encoder_hidden_states is not None:
|
2311
|
-
# `context` projections.
|
2312
|
-
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
2313
|
-
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
2314
|
-
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
2315
|
-
|
2316
|
-
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
2317
|
-
batch_size, -1, attn.heads, head_dim
|
2318
|
-
).transpose(1, 2)
|
2319
|
-
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
2320
|
-
batch_size, -1, attn.heads, head_dim
|
2321
|
-
).transpose(1, 2)
|
2322
|
-
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
2323
|
-
batch_size, -1, attn.heads, head_dim
|
2324
|
-
).transpose(1, 2)
|
2325
|
-
|
2326
|
-
if attn.norm_added_q is not None:
|
2327
|
-
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
2328
|
-
if attn.norm_added_k is not None:
|
2329
|
-
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
2330
|
-
|
2331
|
-
# attention
|
2332
|
-
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
2333
|
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
2334
|
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
2335
|
-
|
2336
|
-
if image_rotary_emb is not None:
|
2337
|
-
from .embeddings import apply_rotary_emb
|
2338
|
-
|
2339
|
-
query = apply_rotary_emb(query, image_rotary_emb)
|
2340
|
-
key = apply_rotary_emb(key, image_rotary_emb)
|
2341
|
-
|
2342
|
-
hidden_states = F.scaled_dot_product_attention(
|
2343
|
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2344
|
-
)
|
2345
|
-
|
2346
|
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2347
|
-
hidden_states = hidden_states.to(query.dtype)
|
2348
|
-
|
2349
|
-
if encoder_hidden_states is not None:
|
2350
|
-
encoder_hidden_states, hidden_states = (
|
2351
|
-
hidden_states[:, : encoder_hidden_states.shape[1]],
|
2352
|
-
hidden_states[:, encoder_hidden_states.shape[1] :],
|
2353
|
-
)
|
2354
|
-
|
2355
|
-
# linear proj
|
2356
|
-
hidden_states = attn.to_out[0](hidden_states)
|
2357
|
-
# dropout
|
2358
|
-
hidden_states = attn.to_out[1](hidden_states)
|
2359
|
-
|
2360
|
-
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
2361
|
-
|
2362
|
-
return hidden_states, encoder_hidden_states
|
2363
|
-
else:
|
2364
|
-
return hidden_states
|
2365
|
-
|
2366
|
-
|
2367
|
-
class FluxAttnProcessor2_0_NPU:
|
2368
|
-
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
2369
|
-
|
2370
|
-
def __init__(self):
|
2371
|
-
if not hasattr(F, "scaled_dot_product_attention"):
|
2372
|
-
raise ImportError(
|
2373
|
-
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
|
2374
|
-
)
|
2375
|
-
|
2376
|
-
def __call__(
|
2377
|
-
self,
|
2378
|
-
attn: Attention,
|
2379
|
-
hidden_states: torch.FloatTensor,
|
2380
|
-
encoder_hidden_states: torch.FloatTensor = None,
|
2381
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
2382
|
-
image_rotary_emb: Optional[torch.Tensor] = None,
|
2383
|
-
) -> torch.FloatTensor:
|
2384
|
-
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2385
|
-
|
2386
|
-
# `sample` projections.
|
2387
|
-
query = attn.to_q(hidden_states)
|
2388
|
-
key = attn.to_k(hidden_states)
|
2389
|
-
value = attn.to_v(hidden_states)
|
2390
|
-
|
2391
|
-
inner_dim = key.shape[-1]
|
2392
|
-
head_dim = inner_dim // attn.heads
|
2393
|
-
|
2394
|
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2395
|
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2396
|
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2397
|
-
|
2398
|
-
if attn.norm_q is not None:
|
2399
|
-
query = attn.norm_q(query)
|
2400
|
-
if attn.norm_k is not None:
|
2401
|
-
key = attn.norm_k(key)
|
2402
|
-
|
2403
|
-
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
2404
|
-
if encoder_hidden_states is not None:
|
2405
|
-
# `context` projections.
|
2406
|
-
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
2407
|
-
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
2408
|
-
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
2409
|
-
|
2410
|
-
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
2411
|
-
batch_size, -1, attn.heads, head_dim
|
2412
|
-
).transpose(1, 2)
|
2413
|
-
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
2414
|
-
batch_size, -1, attn.heads, head_dim
|
2415
|
-
).transpose(1, 2)
|
2416
|
-
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
2417
|
-
batch_size, -1, attn.heads, head_dim
|
2418
|
-
).transpose(1, 2)
|
2419
|
-
|
2420
|
-
if attn.norm_added_q is not None:
|
2421
|
-
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
2422
|
-
if attn.norm_added_k is not None:
|
2423
|
-
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
2424
|
-
|
2425
|
-
# attention
|
2426
|
-
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
2427
|
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
2428
|
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
2429
|
-
|
2430
|
-
if image_rotary_emb is not None:
|
2431
|
-
from .embeddings import apply_rotary_emb
|
2432
|
-
|
2433
|
-
query = apply_rotary_emb(query, image_rotary_emb)
|
2434
|
-
key = apply_rotary_emb(key, image_rotary_emb)
|
2435
|
-
|
2436
|
-
if query.dtype in (torch.float16, torch.bfloat16):
|
2437
|
-
hidden_states = torch_npu.npu_fusion_attention(
|
2438
|
-
query,
|
2439
|
-
key,
|
2440
|
-
value,
|
2441
|
-
attn.heads,
|
2442
|
-
input_layout="BNSD",
|
2443
|
-
pse=None,
|
2444
|
-
scale=1.0 / math.sqrt(query.shape[-1]),
|
2445
|
-
pre_tockens=65536,
|
2446
|
-
next_tockens=65536,
|
2447
|
-
keep_prob=1.0,
|
2448
|
-
sync=False,
|
2449
|
-
inner_precise=0,
|
2450
|
-
)[0]
|
2451
|
-
else:
|
2452
|
-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
2453
|
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2454
|
-
hidden_states = hidden_states.to(query.dtype)
|
2455
|
-
|
2456
|
-
if encoder_hidden_states is not None:
|
2457
|
-
encoder_hidden_states, hidden_states = (
|
2458
|
-
hidden_states[:, : encoder_hidden_states.shape[1]],
|
2459
|
-
hidden_states[:, encoder_hidden_states.shape[1] :],
|
2460
|
-
)
|
2461
|
-
|
2462
|
-
# linear proj
|
2463
|
-
hidden_states = attn.to_out[0](hidden_states)
|
2464
|
-
# dropout
|
2465
|
-
hidden_states = attn.to_out[1](hidden_states)
|
2466
|
-
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
2467
|
-
|
2468
|
-
return hidden_states, encoder_hidden_states
|
2469
|
-
else:
|
2470
|
-
return hidden_states
|
2471
|
-
|
2472
|
-
|
2473
|
-
class FusedFluxAttnProcessor2_0:
|
2474
|
-
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
2475
|
-
|
2476
|
-
def __init__(self):
|
2477
|
-
if not hasattr(F, "scaled_dot_product_attention"):
|
2478
|
-
raise ImportError(
|
2479
|
-
"FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
2480
|
-
)
|
2481
|
-
|
2482
|
-
def __call__(
|
2483
|
-
self,
|
2484
|
-
attn: Attention,
|
2485
|
-
hidden_states: torch.FloatTensor,
|
2486
|
-
encoder_hidden_states: torch.FloatTensor = None,
|
2487
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
2488
|
-
image_rotary_emb: Optional[torch.Tensor] = None,
|
2489
|
-
) -> torch.FloatTensor:
|
2490
|
-
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2491
|
-
|
2492
|
-
# `sample` projections.
|
2493
|
-
qkv = attn.to_qkv(hidden_states)
|
2494
|
-
split_size = qkv.shape[-1] // 3
|
2495
|
-
query, key, value = torch.split(qkv, split_size, dim=-1)
|
2496
|
-
|
2497
|
-
inner_dim = key.shape[-1]
|
2498
|
-
head_dim = inner_dim // attn.heads
|
2499
|
-
|
2500
|
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2501
|
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2502
|
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2503
|
-
|
2504
|
-
if attn.norm_q is not None:
|
2505
|
-
query = attn.norm_q(query)
|
2506
|
-
if attn.norm_k is not None:
|
2507
|
-
key = attn.norm_k(key)
|
2508
|
-
|
2509
|
-
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
2510
|
-
# `context` projections.
|
2511
|
-
if encoder_hidden_states is not None:
|
2512
|
-
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
2513
|
-
split_size = encoder_qkv.shape[-1] // 3
|
2514
|
-
(
|
2515
|
-
encoder_hidden_states_query_proj,
|
2516
|
-
encoder_hidden_states_key_proj,
|
2517
|
-
encoder_hidden_states_value_proj,
|
2518
|
-
) = torch.split(encoder_qkv, split_size, dim=-1)
|
2519
|
-
|
2520
|
-
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
2521
|
-
batch_size, -1, attn.heads, head_dim
|
2522
|
-
).transpose(1, 2)
|
2523
|
-
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
2524
|
-
batch_size, -1, attn.heads, head_dim
|
2525
|
-
).transpose(1, 2)
|
2526
|
-
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
2527
|
-
batch_size, -1, attn.heads, head_dim
|
2528
|
-
).transpose(1, 2)
|
2529
|
-
|
2530
|
-
if attn.norm_added_q is not None:
|
2531
|
-
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
2532
|
-
if attn.norm_added_k is not None:
|
2533
|
-
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
2534
|
-
|
2535
|
-
# attention
|
2536
|
-
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
2537
|
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
2538
|
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
2539
|
-
|
2540
|
-
if image_rotary_emb is not None:
|
2541
|
-
from .embeddings import apply_rotary_emb
|
2542
|
-
|
2543
|
-
query = apply_rotary_emb(query, image_rotary_emb)
|
2544
|
-
key = apply_rotary_emb(key, image_rotary_emb)
|
2545
|
-
|
2546
|
-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
2547
|
-
|
2548
|
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2549
|
-
hidden_states = hidden_states.to(query.dtype)
|
2550
|
-
|
2551
|
-
if encoder_hidden_states is not None:
|
2552
|
-
encoder_hidden_states, hidden_states = (
|
2553
|
-
hidden_states[:, : encoder_hidden_states.shape[1]],
|
2554
|
-
hidden_states[:, encoder_hidden_states.shape[1] :],
|
2555
|
-
)
|
2556
|
-
|
2557
|
-
# linear proj
|
2558
|
-
hidden_states = attn.to_out[0](hidden_states)
|
2559
|
-
# dropout
|
2560
|
-
hidden_states = attn.to_out[1](hidden_states)
|
2561
|
-
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
2562
|
-
|
2563
|
-
return hidden_states, encoder_hidden_states
|
2564
|
-
else:
|
2565
|
-
return hidden_states
|
2566
|
-
|
2567
|
-
|
2568
|
-
class FusedFluxAttnProcessor2_0_NPU:
|
2569
|
-
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
2570
|
-
|
2571
|
-
def __init__(self):
|
2572
|
-
if not hasattr(F, "scaled_dot_product_attention"):
|
2573
|
-
raise ImportError(
|
2574
|
-
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
|
2575
|
-
)
|
2576
|
-
|
2577
|
-
def __call__(
|
2578
|
-
self,
|
2579
|
-
attn: Attention,
|
2580
|
-
hidden_states: torch.FloatTensor,
|
2581
|
-
encoder_hidden_states: torch.FloatTensor = None,
|
2582
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
2583
|
-
image_rotary_emb: Optional[torch.Tensor] = None,
|
2584
|
-
) -> torch.FloatTensor:
|
2585
|
-
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2586
|
-
|
2587
|
-
# `sample` projections.
|
2588
|
-
qkv = attn.to_qkv(hidden_states)
|
2589
|
-
split_size = qkv.shape[-1] // 3
|
2590
|
-
query, key, value = torch.split(qkv, split_size, dim=-1)
|
2591
|
-
|
2592
|
-
inner_dim = key.shape[-1]
|
2593
|
-
head_dim = inner_dim // attn.heads
|
2594
|
-
|
2595
|
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2596
|
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2597
|
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2598
|
-
|
2599
|
-
if attn.norm_q is not None:
|
2600
|
-
query = attn.norm_q(query)
|
2601
|
-
if attn.norm_k is not None:
|
2602
|
-
key = attn.norm_k(key)
|
2603
|
-
|
2604
|
-
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
2605
|
-
# `context` projections.
|
2606
|
-
if encoder_hidden_states is not None:
|
2607
|
-
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
2608
|
-
split_size = encoder_qkv.shape[-1] // 3
|
2609
|
-
(
|
2610
|
-
encoder_hidden_states_query_proj,
|
2611
|
-
encoder_hidden_states_key_proj,
|
2612
|
-
encoder_hidden_states_value_proj,
|
2613
|
-
) = torch.split(encoder_qkv, split_size, dim=-1)
|
2614
|
-
|
2615
|
-
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
2616
|
-
batch_size, -1, attn.heads, head_dim
|
2617
|
-
).transpose(1, 2)
|
2618
|
-
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
2619
|
-
batch_size, -1, attn.heads, head_dim
|
2620
|
-
).transpose(1, 2)
|
2621
|
-
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
2622
|
-
batch_size, -1, attn.heads, head_dim
|
2623
|
-
).transpose(1, 2)
|
2624
|
-
|
2625
|
-
if attn.norm_added_q is not None:
|
2626
|
-
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
2627
|
-
if attn.norm_added_k is not None:
|
2628
|
-
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
2629
|
-
|
2630
|
-
# attention
|
2631
|
-
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
2632
|
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
2633
|
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
2634
|
-
|
2635
|
-
if image_rotary_emb is not None:
|
2636
|
-
from .embeddings import apply_rotary_emb
|
2637
|
-
|
2638
|
-
query = apply_rotary_emb(query, image_rotary_emb)
|
2639
|
-
key = apply_rotary_emb(key, image_rotary_emb)
|
2640
|
-
|
2641
|
-
if query.dtype in (torch.float16, torch.bfloat16):
|
2642
|
-
hidden_states = torch_npu.npu_fusion_attention(
|
2643
|
-
query,
|
2644
|
-
key,
|
2645
|
-
value,
|
2646
|
-
attn.heads,
|
2647
|
-
input_layout="BNSD",
|
2648
|
-
pse=None,
|
2649
|
-
scale=1.0 / math.sqrt(query.shape[-1]),
|
2650
|
-
pre_tockens=65536,
|
2651
|
-
next_tockens=65536,
|
2652
|
-
keep_prob=1.0,
|
2653
|
-
sync=False,
|
2654
|
-
inner_precise=0,
|
2655
|
-
)[0]
|
2656
|
-
else:
|
2657
|
-
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
2658
|
-
|
2659
|
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2660
|
-
hidden_states = hidden_states.to(query.dtype)
|
2661
|
-
|
2662
|
-
if encoder_hidden_states is not None:
|
2663
|
-
encoder_hidden_states, hidden_states = (
|
2664
|
-
hidden_states[:, : encoder_hidden_states.shape[1]],
|
2665
|
-
hidden_states[:, encoder_hidden_states.shape[1] :],
|
2666
|
-
)
|
2667
|
-
|
2668
|
-
# linear proj
|
2669
|
-
hidden_states = attn.to_out[0](hidden_states)
|
2670
|
-
# dropout
|
2671
|
-
hidden_states = attn.to_out[1](hidden_states)
|
2672
|
-
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
2673
|
-
|
2674
|
-
return hidden_states, encoder_hidden_states
|
2675
|
-
else:
|
2676
|
-
return hidden_states
|
2677
|
-
|
2678
|
-
|
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
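The entire `FluxIPAdapterJointAttnProcessor2_0` implementation is removed here; it reappears near the end of this diff as a thin deprecation shim whose `__new__` forwards to `FluxIPAdapterAttnProcessor` in `transformers/transformer_flux.py`. A hedged migration sketch, assuming the relocated class keeps the old constructor signature (the shim forwards `*args`/`**kwargs` unchanged, which suggests it does):

```python
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

# Same arguments the removed class accepted; the values here are illustrative.
processor = FluxIPAdapterAttnProcessor(
    hidden_size=3072,
    cross_attention_dim=4096,
    num_tokens=(4,),
    scale=1.0,
)
```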
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3449,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states


-class XLAFluxFlashAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
-    """
-
-    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-        if is_torch_xla_version("<", "2.3"):
-            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
-        if is_spmd() and is_torch_xla_version("<", "2.4"):
-            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
-        self.partition_spec = partition_spec
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        query /= math.sqrt(head_dim)
-        hidden_states = flash_attention(query, key, value, causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
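`XLAFluxFlashAttnProcessor2_0` is likewise removed as a standalone implementation. Per the shim added at the end of this diff, XLA support is now routed through the consolidated `FluxAttnProcessor` via a private backend tag rather than a separate `__call__`. A sketch of what that shim does, restating the added lines below:

```python
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

processor = FluxAttnProcessor()
processor._attention_backend = "_native_xla"  # private attribute, set by the shim
```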
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -3972,7 +3324,7 @@ class PAGHunyuanAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
     used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
-    variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+    variant of the processor employs [Perturbed Attention Guidance](https://huggingface.co/papers/2403.17377).
     """

     def __init__(self):
@@ -4095,7 +3447,7 @@ class PAGCFGHunyuanAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
     used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
-    variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+    variant of the processor employs [Perturbed Attention Guidance](https://huggingface.co/papers/2403.17377).
     """

     def __init__(self):
@@ -4828,7 +4180,7 @@ class SlicedAttnAddedKVProcessor:

 class SpatialNorm(nn.Module):
     """
-    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+    Spatially conditioned normalization as defined in https://huggingface.co/papers/2209.09002.

     Args:
         f_channels (`int`):
@@ -5693,7 +5045,7 @@ class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
 class PAGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    PAG reference: https://arxiv.org/abs/2403.17377
+    PAG reference: https://huggingface.co/papers/2403.17377
     """

     def __init__(self):
@@ -5792,7 +5144,7 @@ class PAGIdentitySelfAttnProcessor2_0:
 class PAGCFGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    PAG reference: https://arxiv.org/abs/2403.17377
+    PAG reference: https://huggingface.co/papers/2403.17377
     """

     def __init__(self):
@@ -5988,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
     pass


-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(self):
-        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
-        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
-        super().__init__()
-
-
 class SanaLinearAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product linear attention.
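Note the churn around `FluxSingleAttnProcessor2_0`: the subclass-based shim removed above carried a 0.32.0 removal notice, while the replacement added in the next hunk is `__new__`-based with a 1.0.0 horizon. The observable behavior, sketched under the assumption that the deprecation warning is non-fatal on current versions:

```python
from diffusers.models.attention_processor import FluxSingleAttnProcessor2_0

proc = FluxSingleAttnProcessor2_0()  # emits a deprecation warning via `deprecate`
print(type(proc).__name__)           # "FluxAttnProcessor": __new__ returns the new class
```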
@@ -6163,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
         return hidden_states


+class FluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_native_npu"
+        return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    def __new__(self):
+        deprecation_message = (
+            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_fused_npu"
+        return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+            "alternative solution to using XLA Flash Attention will be provided in the future."
+        )
+        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+            deprecation_message = (
+                "partition_spec was not used in the processor implementation when it was added. Passing it "
+                "is a no-op and support for it will be removed."
+            )
+            deprecate("partition_spec", "1.0.0", deprecation_message)
+
+        processor = FluxAttnProcessor(*args, **kwargs)
+        processor._attention_backend = "_native_xla"
+        return processor
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,