diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/hooks/hooks.py
ADDED
@@ -0,0 +1,236 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Optional, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__)  # pylint: disable=invalid-name


class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def __init__(self):
        self.fn_ref: "HookFunctionReference" = None

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is deinitalized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[Str, Any]`):
                The keyword arguments passed to the module.
        Returns:
            `Tuple[Tuple[Any], Dict[Str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass been executed just before this event.
            output (`Any`):
                The output of the module.
        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module


class HookFunctionReference:
    def __init__(self) -> None:
        """A container class that maintains mutable references to forward pass functions in a hook chain.

        Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
        entire forward pass structure.

        Attributes:
            pre_forward: A callable that processes inputs before the main forward pass.
            post_forward: A callable that processes outputs after the main forward pass.
            forward: The current forward function in the hook chain.
            original_forward: The original forward function, stored when a hook provides a custom new_forward.

        The class enables hook removal by allowing updates to the forward chain through reference modification rather
        than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
        be updated, preserving the execution order of the remaining hooks.
        """
        self.pre_forward = None
        self.post_forward = None
        self.forward = None
        self.original_forward = None


class HookRegistry:
    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, ModelHook] = {}

        self._module_ref = module_ref
        self._hook_order = []
        self._fn_refs = []

    def register_hook(self, hook: ModelHook, name: str) -> None:
        if name in self.hooks.keys():
            raise ValueError(
                f"Hook with name {name} already exists in the registry. Please use a different name or "
                f"first remove the existing hook and then add a new one."
            )

        self._module_ref = hook.initialize_hook(self._module_ref)

        def create_new_forward(function_reference: HookFunctionReference):
            def new_forward(module, *args, **kwargs):
                args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
                output = function_reference.forward(*args, **kwargs)
                return function_reference.post_forward(module, output)

            return new_forward

        forward = self._module_ref.forward

        fn_ref = HookFunctionReference()
        fn_ref.pre_forward = hook.pre_forward
        fn_ref.post_forward = hook.post_forward
        fn_ref.forward = forward

        if hasattr(hook, "new_forward"):
            fn_ref.original_forward = forward
            fn_ref.forward = functools.update_wrapper(
                functools.partial(hook.new_forward, self._module_ref), hook.new_forward
            )

        rewritten_forward = create_new_forward(fn_ref)
        self._module_ref.forward = functools.update_wrapper(
            functools.partial(rewritten_forward, self._module_ref), rewritten_forward
        )

        hook.fn_ref = fn_ref
        self.hooks[name] = hook
        self._hook_order.append(name)
        self._fn_refs.append(fn_ref)

    def get_hook(self, name: str) -> Optional[ModelHook]:
        return self.hooks.get(name, None)

    def remove_hook(self, name: str, recurse: bool = True) -> None:
        if name in self.hooks.keys():
            num_hooks = len(self._hook_order)
            hook = self.hooks[name]
            index = self._hook_order.index(name)
            fn_ref = self._fn_refs[index]

            old_forward = fn_ref.forward
            if fn_ref.original_forward is not None:
                old_forward = fn_ref.original_forward

            if index == num_hooks - 1:
                self._module_ref.forward = old_forward
            else:
                self._fn_refs[index + 1].forward = old_forward

            self._module_ref = hook.deinitalize_hook(self._module_ref)
            del self.hooks[name]
            self._hook_order.pop(index)
            self._fn_refs.pop(index)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.remove_hook(name, recurse=False)

    def reset_stateful_hooks(self, recurse: bool = True) -> None:
        for hook_name in reversed(self._hook_order):
            hook = self.hooks[hook_name]
            if hook._is_stateful:
                hook.reset_state(self._module_ref)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.reset_stateful_hooks(recurse=False)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        registry_repr = ""
        for i, hook_name in enumerate(self._hook_order):
            if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
                hook_repr = self.hooks[hook_name].__repr__()
            else:
                hook_repr = self.hooks[hook_name].__class__.__name__
            registry_repr += f"  ({i}) {hook_name} - {hook_repr}"
            if i < len(self._hook_order) - 1:
                registry_repr += "\n"
        return f"HookRegistry(\n{registry_repr}\n)"
diffusers/hooks/layerwise_casting.py
ADDED
@@ -0,0 +1,245 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Optional, Tuple, Type, Union

import torch

from ..utils import get_logger, is_peft_available, is_peft_version
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name


# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)

DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
    from peft.helpers import disable_input_dtype_casting
    from peft.tuners.tuners_utils import BaseTunerLayer


class LayerwiseCastingHook(ModelHook):
    r"""
    A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
    for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
    footprint.
    """

    _is_stateful = False

    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
        self.storage_dtype = storage_dtype
        self.compute_dtype = compute_dtype
        self.non_blocking = non_blocking

    def initialize_hook(self, module: torch.nn.Module):
        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
        return module

    def deinitalize_hook(self, module: torch.nn.Module):
        raise NotImplementedError(
            "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
            "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
            "will lead to precision loss, which might have an impact on the model's generation quality. The model should "
            "be re-initialized and loaded in the original dtype."
        )

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
        return output


class PeftInputAutocastDisableHook(ModelHook):
    r"""
    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
    casts the inputs to the weight dtype of the module, which can lead to precision loss.

    The reasons for needing this are:
        - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
          inputs will be casted to the, possibly lower precision, storage dtype. Reference:
          https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
        - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
          that the inputs are casted to the computation dtype correctly always. However, there are two goals we are
          hoping to achieve:
              1. Making forward implementations independent of device/dtype casting operations as much as possible.
              2. Peforming inference without losing information from casting to different precisions. With the current
                 PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
                 with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
                 torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
                 forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
                 LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
    """

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        with disable_input_dtype_casting(module):
            return self.fn_ref.original_forward(*args, **kwargs)


def apply_layerwise_casting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
) -> None:
    r"""
    Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
    nn.Module using diffusers layers or pytorch primitives.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
    ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    ... )

    >>> apply_layerwise_casting(
    ...     transformer,
    ...     storage_dtype=torch.float8_e4m3fn,
    ...     compute_dtype=torch.bfloat16,
    ...     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
    ...     non_blocking=True,
    ... )
    ```

    Args:
        module (`torch.nn.Module`):
            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
            precision dtype for storage.
        storage_dtype (`torch.dtype`):
            The dtype to cast the module to before/after the forward pass for storage.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass for computation.
        skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
            A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
            to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
            alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
            instead of its internal submodules.
        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
            A list of module classes to skip during the layerwise casting process.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the weight casting operations are non-blocking.
    """
    if skip_modules_pattern == "auto":
        skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN

    if skip_modules_classes is None and skip_modules_pattern is None:
        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return

    _apply_layerwise_casting(
        module,
        storage_dtype,
        compute_dtype,
        skip_modules_pattern,
        skip_modules_classes,
        non_blocking,
    )
    _disable_peft_input_autocast(module)


def _apply_layerwise_casting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
    _prefix: str = "",
) -> None:
    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
    )
    if should_skip:
        logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
        return

    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
        logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return

    for name, submodule in module.named_children():
        layer_name = f"{_prefix}.{name}" if _prefix else name
        _apply_layerwise_casting(
            submodule,
            storage_dtype,
            compute_dtype,
            skip_modules_pattern,
            skip_modules_classes,
            non_blocking,
            _prefix=layer_name,
        )


def apply_layerwise_casting_hook(
    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
) -> None:
    r"""
    Applies a `LayerwiseCastingHook` to a given module.

    Args:
        module (`torch.nn.Module`):
            The module to attach the hook to.
        storage_dtype (`torch.dtype`):
            The dtype to cast the module to before the forward pass.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass.
        non_blocking (`bool`):
            If `True`, the weight casting operations are non-blocking.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)


def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
    for submodule in module.modules():
        if (
            hasattr(submodule, "_diffusers_hook")
            and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
        ):
            return True
    return False


def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
        return
    for submodule in module.modules():
        if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            hook = PeftInputAutocastDisableHook()
            registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)