diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,311 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import re
|
16
|
+
from dataclasses import dataclass
|
17
|
+
from typing import Any, Callable, Optional, Tuple, Union
|
18
|
+
|
19
|
+
import torch
|
20
|
+
|
21
|
+
from ..models.attention_processor import Attention, MochiAttention
|
22
|
+
from ..utils import logging
|
23
|
+
from .hooks import HookRegistry, ModelHook
|
24
|
+
|
25
|
+
|
26
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
27
|
+
|
28
|
+
|
29
|
+
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
|
30
|
+
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
31
|
+
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
32
|
+
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
33
|
+
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
34
|
+
|
35
|
+
|
36
|
+
@dataclass
class PyramidAttentionBroadcastConfig:
    r"""
    Configuration for Pyramid Attention Broadcast.

    Args:
        spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            The number of times a specific spatial attention broadcast is skipped before computing the attention states
            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
            old attention states will be re-used) before computing the new attention states again.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            The number of times a specific temporal attention broadcast is skipped before computing the attention
            states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
            (i.e., old attention states will be re-used) before computing the new attention states again.
        cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            The number of times a specific cross-attention broadcast is skipped before computing the attention states
            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
            old attention states will be re-used) before computing the new attention states again.
        spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
            The range of timesteps to skip in the spatial attention layer. The attention computations will be
            conditionally skipped if the current timestep is within the specified range.
        temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
            The range of timesteps to skip in the temporal attention layer. The attention computations will be
            conditionally skipped if the current timestep is within the specified range.
        cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
            The range of timesteps to skip in the cross-attention layer. The attention computations will be
            conditionally skipped if the current timestep is within the specified range.
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks", "single_transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
            The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
        current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`):
            A zero-argument callable returning the current inference timestep; used to decide whether the current
            step falls inside the configured skip ranges.
    """

    spatial_attention_block_skip_range: Optional[int] = None
    temporal_attention_block_skip_range: Optional[int] = None
    cross_attention_block_skip_range: Optional[int] = None

    spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
    temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
    cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)

    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS

    # Annotated Optional: the default really is None, and callers must supply the
    # callback before the hook is used (the hook dereferences it on every forward).
    current_timestep_callback: Optional[Callable[[], int]] = None

    # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
    # so not added for now)

    def __repr__(self) -> str:
        return (
            f"PyramidAttentionBroadcastConfig(\n"
            f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f"  cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
            f"  spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f"  temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f"  cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
            f"  spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f"  temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f"  cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
            f"  current_timestep_callback={self.current_timestep_callback}\n"
            ")"
        )
|
103
|
+
|
104
|
+
|
105
|
+
class PyramidAttentionBroadcastState:
    r"""
    State for Pyramid Attention Broadcast.

    Attributes:
        iteration (`int`):
            The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
            called before starting a new inference forward pass for PAB to work correctly.
        cache (`Any`):
            The cached output from the previous forward pass. This is used to re-use the attention states when the
            attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
    """

    def __init__(self) -> None:
        # iteration counts forward passes through the hooked module; cache holds the
        # last computed output (None until the first compute).
        self.iteration = 0
        self.cache = None

    def reset(self):
        """Clear the iteration counter and cached output for a fresh inference pass."""
        self.iteration = 0
        self.cache = None

    def __repr__(self):
        # The cache may be a tensor OR a tuple of tensors (see class docstring), so
        # only read `.shape`/`.dtype` when it is actually a tensor; the previous
        # implementation raised AttributeError for tuple caches.
        if self.cache is None:
            cache_repr = "None"
        elif torch.is_tensor(self.cache):
            cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
        else:
            # Tuple of tensors (or other container) — report its type and length.
            cache_repr = f"{type(self.cache).__name__}(len={len(self.cache)})"
        return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
|
133
|
+
|
134
|
+
|
135
|
+
class PyramidAttentionBroadcastHook(ModelHook):
|
136
|
+
r"""A hook that applies Pyramid Attention Broadcast to a given module."""
|
137
|
+
|
138
|
+
_is_stateful = True
|
139
|
+
|
140
|
+
def __init__(
|
141
|
+
self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
|
142
|
+
) -> None:
|
143
|
+
super().__init__()
|
144
|
+
|
145
|
+
self.timestep_skip_range = timestep_skip_range
|
146
|
+
self.block_skip_range = block_skip_range
|
147
|
+
self.current_timestep_callback = current_timestep_callback
|
148
|
+
|
149
|
+
def initialize_hook(self, module):
|
150
|
+
self.state = PyramidAttentionBroadcastState()
|
151
|
+
return module
|
152
|
+
|
153
|
+
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
154
|
+
is_within_timestep_range = (
|
155
|
+
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
|
156
|
+
)
|
157
|
+
should_compute_attention = (
|
158
|
+
self.state.cache is None
|
159
|
+
or self.state.iteration == 0
|
160
|
+
or not is_within_timestep_range
|
161
|
+
or self.state.iteration % self.block_skip_range == 0
|
162
|
+
)
|
163
|
+
|
164
|
+
if should_compute_attention:
|
165
|
+
output = self.fn_ref.original_forward(*args, **kwargs)
|
166
|
+
else:
|
167
|
+
output = self.state.cache
|
168
|
+
|
169
|
+
self.state.cache = output
|
170
|
+
self.state.iteration += 1
|
171
|
+
return output
|
172
|
+
|
173
|
+
def reset_state(self, module: torch.nn.Module) -> None:
|
174
|
+
self.state.reset()
|
175
|
+
return module
|
176
|
+
|
177
|
+
|
178
|
+
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
|
179
|
+
r"""
|
180
|
+
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
|
181
|
+
|
182
|
+
PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
|
183
|
+
reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
|
184
|
+
similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
|
185
|
+
spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
|
186
|
+
than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process.
|
187
|
+
|
188
|
+
Args:
|
189
|
+
module (`torch.nn.Module`):
|
190
|
+
The module to apply Pyramid Attention Broadcast to.
|
191
|
+
config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`):
|
192
|
+
The configuration to use for Pyramid Attention Broadcast.
|
193
|
+
|
194
|
+
Example:
|
195
|
+
|
196
|
+
```python
|
197
|
+
>>> import torch
|
198
|
+
>>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
199
|
+
>>> from diffusers.utils import export_to_video
|
200
|
+
|
201
|
+
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
202
|
+
>>> pipe.to("cuda")
|
203
|
+
|
204
|
+
>>> config = PyramidAttentionBroadcastConfig(
|
205
|
+
... spatial_attention_block_skip_range=2,
|
206
|
+
... spatial_attention_timestep_skip_range=(100, 800),
|
207
|
+
... current_timestep_callback=lambda: pipe.current_timestep,
|
208
|
+
... )
|
209
|
+
>>> apply_pyramid_attention_broadcast(pipe.transformer, config)
|
210
|
+
```
|
211
|
+
"""
|
212
|
+
if config.current_timestep_callback is None:
|
213
|
+
raise ValueError(
|
214
|
+
"The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
|
215
|
+
)
|
216
|
+
|
217
|
+
if (
|
218
|
+
config.spatial_attention_block_skip_range is None
|
219
|
+
and config.temporal_attention_block_skip_range is None
|
220
|
+
and config.cross_attention_block_skip_range is None
|
221
|
+
):
|
222
|
+
logger.warning(
|
223
|
+
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
|
224
|
+
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
|
225
|
+
"To avoid this warning, please set one of the above parameters."
|
226
|
+
)
|
227
|
+
config.spatial_attention_block_skip_range = 2
|
228
|
+
|
229
|
+
for name, submodule in module.named_modules():
|
230
|
+
if not isinstance(submodule, _ATTENTION_CLASSES):
|
231
|
+
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
|
232
|
+
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
|
233
|
+
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
|
234
|
+
continue
|
235
|
+
_apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
|
236
|
+
|
237
|
+
|
238
|
+
def _apply_pyramid_attention_broadcast_on_attention_class(
|
239
|
+
name: str, module: Attention, config: PyramidAttentionBroadcastConfig
|
240
|
+
) -> bool:
|
241
|
+
is_spatial_self_attention = (
|
242
|
+
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
|
243
|
+
and config.spatial_attention_block_skip_range is not None
|
244
|
+
and not getattr(module, "is_cross_attention", False)
|
245
|
+
)
|
246
|
+
is_temporal_self_attention = (
|
247
|
+
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
|
248
|
+
and config.temporal_attention_block_skip_range is not None
|
249
|
+
and not getattr(module, "is_cross_attention", False)
|
250
|
+
)
|
251
|
+
is_cross_attention = (
|
252
|
+
any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
|
253
|
+
and config.cross_attention_block_skip_range is not None
|
254
|
+
and getattr(module, "is_cross_attention", False)
|
255
|
+
)
|
256
|
+
|
257
|
+
block_skip_range, timestep_skip_range, block_type = None, None, None
|
258
|
+
if is_spatial_self_attention:
|
259
|
+
block_skip_range = config.spatial_attention_block_skip_range
|
260
|
+
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
261
|
+
block_type = "spatial"
|
262
|
+
elif is_temporal_self_attention:
|
263
|
+
block_skip_range = config.temporal_attention_block_skip_range
|
264
|
+
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
265
|
+
block_type = "temporal"
|
266
|
+
elif is_cross_attention:
|
267
|
+
block_skip_range = config.cross_attention_block_skip_range
|
268
|
+
timestep_skip_range = config.cross_attention_timestep_skip_range
|
269
|
+
block_type = "cross"
|
270
|
+
|
271
|
+
if block_skip_range is None or timestep_skip_range is None:
|
272
|
+
logger.info(
|
273
|
+
f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
|
274
|
+
f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
|
275
|
+
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
|
276
|
+
f"block identifiers in the configuration."
|
277
|
+
)
|
278
|
+
return False
|
279
|
+
|
280
|
+
logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
|
281
|
+
_apply_pyramid_attention_broadcast_hook(
|
282
|
+
module, timestep_skip_range, block_skip_range, config.current_timestep_callback
|
283
|
+
)
|
284
|
+
return True
|
285
|
+
|
286
|
+
|
287
|
+
def _apply_pyramid_attention_broadcast_hook(
|
288
|
+
module: Union[Attention, MochiAttention],
|
289
|
+
timestep_skip_range: Tuple[int, int],
|
290
|
+
block_skip_range: int,
|
291
|
+
current_timestep_callback: Callable[[], int],
|
292
|
+
):
|
293
|
+
r"""
|
294
|
+
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
|
295
|
+
|
296
|
+
Args:
|
297
|
+
module (`torch.nn.Module`):
|
298
|
+
The module to apply Pyramid Attention Broadcast to.
|
299
|
+
timestep_skip_range (`Tuple[int, int]`):
|
300
|
+
The range of timesteps to skip in the attention layer. The attention computations will be conditionally
|
301
|
+
skipped if the current timestep is within the specified range.
|
302
|
+
block_skip_range (`int`):
|
303
|
+
The number of times a specific attention broadcast is skipped before computing the attention states to
|
304
|
+
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
|
305
|
+
attention states will be re-used) before computing the new attention states again.
|
306
|
+
current_timestep_callback (`Callable[[], int]`):
|
307
|
+
A callback function that returns the current inference timestep.
|
308
|
+
"""
|
309
|
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
310
|
+
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
|
311
|
+
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
|
diffusers/loaders/__init__.py
CHANGED
@@ -70,9 +70,12 @@ if is_torch_available():
|
|
70
70
|
"LoraLoaderMixin",
|
71
71
|
"FluxLoraLoaderMixin",
|
72
72
|
"CogVideoXLoraLoaderMixin",
|
73
|
+
"CogView4LoraLoaderMixin",
|
73
74
|
"Mochi1LoraLoaderMixin",
|
74
75
|
"HunyuanVideoLoraLoaderMixin",
|
75
76
|
"SanaLoraLoaderMixin",
|
77
|
+
"Lumina2LoraLoaderMixin",
|
78
|
+
"WanLoraLoaderMixin",
|
76
79
|
]
|
77
80
|
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
78
81
|
_import_structure["ip_adapter"] = [
|
@@ -101,15 +104,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
101
104
|
from .lora_pipeline import (
|
102
105
|
AmusedLoraLoaderMixin,
|
103
106
|
CogVideoXLoraLoaderMixin,
|
107
|
+
CogView4LoraLoaderMixin,
|
104
108
|
FluxLoraLoaderMixin,
|
105
109
|
HunyuanVideoLoraLoaderMixin,
|
106
110
|
LoraLoaderMixin,
|
107
111
|
LTXVideoLoraLoaderMixin,
|
112
|
+
Lumina2LoraLoaderMixin,
|
108
113
|
Mochi1LoraLoaderMixin,
|
109
114
|
SanaLoraLoaderMixin,
|
110
115
|
SD3LoraLoaderMixin,
|
111
116
|
StableDiffusionLoraLoaderMixin,
|
112
117
|
StableDiffusionXLLoraLoaderMixin,
|
118
|
+
WanLoraLoaderMixin,
|
113
119
|
)
|
114
120
|
from .single_file import FromSingleFileMixin
|
115
121
|
from .textual_inversion import TextualInversionLoaderMixin
|
diffusers/loaders/ip_adapter.py
CHANGED
@@ -23,7 +23,9 @@ from safetensors import safe_open
|
|
23
23
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
24
24
|
from ..utils import (
|
25
25
|
USE_PEFT_BACKEND,
|
26
|
+
_get_detailed_type,
|
26
27
|
_get_model_file,
|
28
|
+
_is_valid_type,
|
27
29
|
is_accelerate_available,
|
28
30
|
is_torch_version,
|
29
31
|
is_transformers_available,
|
@@ -213,7 +215,8 @@ class IPAdapterMixin:
|
|
213
215
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
214
216
|
cache_dir=cache_dir,
|
215
217
|
local_files_only=local_files_only,
|
216
|
-
|
218
|
+
torch_dtype=self.dtype,
|
219
|
+
).to(self.device)
|
217
220
|
self.register_modules(image_encoder=image_encoder)
|
218
221
|
else:
|
219
222
|
raise ValueError(
|
@@ -292,8 +295,7 @@ class IPAdapterMixin:
|
|
292
295
|
):
|
293
296
|
if len(scale_configs) != len(attn_processor.scale):
|
294
297
|
raise ValueError(
|
295
|
-
f"Cannot assign {len(scale_configs)} scale_configs to "
|
296
|
-
f"{len(attn_processor.scale)} IP-Adapter."
|
298
|
+
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
297
299
|
)
|
298
300
|
elif len(scale_configs) == 1:
|
299
301
|
scale_configs = scale_configs * len(attn_processor.scale)
|
@@ -524,8 +526,9 @@ class FluxIPAdapterMixin:
|
|
524
526
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
525
527
|
cache_dir=cache_dir,
|
526
528
|
local_files_only=local_files_only,
|
529
|
+
dtype=image_encoder_dtype,
|
527
530
|
)
|
528
|
-
.to(self.device
|
531
|
+
.to(self.device)
|
529
532
|
.eval()
|
530
533
|
)
|
531
534
|
self.register_modules(image_encoder=image_encoder)
|
@@ -577,29 +580,36 @@ class FluxIPAdapterMixin:
|
|
577
580
|
pipeline.set_ip_adapter_scale(ip_strengths)
|
578
581
|
```
|
579
582
|
"""
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
583
|
+
|
584
|
+
scale_type = Union[int, float]
|
585
|
+
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
586
|
+
num_layers = self.transformer.config.num_layers
|
587
|
+
|
588
|
+
# Single value for all layers of all IP-Adapters
|
589
|
+
if isinstance(scale, scale_type):
|
590
|
+
scale = [scale for _ in range(num_ip_adapters)]
|
591
|
+
# List of per-layer scales for a single IP-Adapter
|
592
|
+
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
586
593
|
scale = [scale]
|
594
|
+
# Invalid scale type
|
595
|
+
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
596
|
+
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
587
597
|
|
588
|
-
|
598
|
+
if len(scale) != num_ip_adapters:
|
599
|
+
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
589
600
|
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
key_id += 1
|
601
|
+
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
602
|
+
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
603
|
+
raise ValueError(
|
604
|
+
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
605
|
+
)
|
606
|
+
|
607
|
+
# Scalars are transformed to lists with length num_layers
|
608
|
+
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
609
|
+
|
610
|
+
# Set scales. zip over scale_configs prevents going into single transformer layers
|
611
|
+
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
612
|
+
attn_processor.scale = scale
|
603
613
|
|
604
614
|
def unload_ip_adapter(self):
|
605
615
|
"""
|
@@ -793,12 +803,10 @@ class SD3IPAdapterMixin:
|
|
793
803
|
}
|
794
804
|
|
795
805
|
self.register_modules(
|
796
|
-
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs)
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
self.device, dtype=self.dtype
|
801
|
-
),
|
806
|
+
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
807
|
+
image_encoder=SiglipVisionModel.from_pretrained(
|
808
|
+
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
809
|
+
).to(self.device),
|
802
810
|
)
|
803
811
|
else:
|
804
812
|
raise ValueError(
|