diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -154,10 +154,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
|
154
154
|
self.register_to_config(block_out_channels=decoder_block_out_channels)
|
155
155
|
self.register_to_config(force_upcast=False)
|
156
156
|
|
157
|
-
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
158
|
-
if isinstance(module, (EncoderTiny, DecoderTiny)):
|
159
|
-
module.gradient_checkpointing = value
|
160
|
-
|
161
157
|
def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
|
162
158
|
"""raw latents -> [0, 1]"""
|
163
159
|
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
|
@@ -60,7 +60,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
|
60
60
|
|
61
61
|
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
62
62
|
>>> pipe = StableDiffusionPipeline.from_pretrained(
|
63
|
-
... "
|
63
|
+
... "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
|
64
64
|
... ).to("cuda")
|
65
65
|
|
66
66
|
>>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
|
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
|
68
68
|
```
|
69
69
|
"""
|
70
70
|
|
71
|
+
_supports_group_offloading = False
|
72
|
+
|
71
73
|
@register_to_config
|
72
74
|
def __init__(
|
73
75
|
self,
|
@@ -18,7 +18,7 @@ import numpy as np
|
|
18
18
|
import torch
|
19
19
|
import torch.nn as nn
|
20
20
|
|
21
|
-
from ...utils import BaseOutput
|
21
|
+
from ...utils import BaseOutput
|
22
22
|
from ...utils.torch_utils import randn_tensor
|
23
23
|
from ..activations import get_activation
|
24
24
|
from ..attention_processor import SpatialNorm
|
@@ -156,28 +156,11 @@ class Encoder(nn.Module):
|
|
156
156
|
sample = self.conv_in(sample)
|
157
157
|
|
158
158
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
159
|
-
|
160
|
-
def create_custom_forward(module):
|
161
|
-
def custom_forward(*inputs):
|
162
|
-
return module(*inputs)
|
163
|
-
|
164
|
-
return custom_forward
|
165
|
-
|
166
159
|
# down
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
)
|
172
|
-
# middle
|
173
|
-
sample = torch.utils.checkpoint.checkpoint(
|
174
|
-
create_custom_forward(self.mid_block), sample, use_reentrant=False
|
175
|
-
)
|
176
|
-
else:
|
177
|
-
for down_block in self.down_blocks:
|
178
|
-
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
179
|
-
# middle
|
180
|
-
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
160
|
+
for down_block in self.down_blocks:
|
161
|
+
sample = self._gradient_checkpointing_func(down_block, sample)
|
162
|
+
# middle
|
163
|
+
sample = self._gradient_checkpointing_func(self.mid_block, sample)
|
181
164
|
|
182
165
|
else:
|
183
166
|
# down
|
@@ -305,41 +288,13 @@ class Decoder(nn.Module):
|
|
305
288
|
|
306
289
|
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
307
290
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
291
|
+
# middle
|
292
|
+
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
293
|
+
sample = sample.to(upscale_dtype)
|
308
294
|
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
return custom_forward
|
314
|
-
|
315
|
-
if is_torch_version(">=", "1.11.0"):
|
316
|
-
# middle
|
317
|
-
sample = torch.utils.checkpoint.checkpoint(
|
318
|
-
create_custom_forward(self.mid_block),
|
319
|
-
sample,
|
320
|
-
latent_embeds,
|
321
|
-
use_reentrant=False,
|
322
|
-
)
|
323
|
-
sample = sample.to(upscale_dtype)
|
324
|
-
|
325
|
-
# up
|
326
|
-
for up_block in self.up_blocks:
|
327
|
-
sample = torch.utils.checkpoint.checkpoint(
|
328
|
-
create_custom_forward(up_block),
|
329
|
-
sample,
|
330
|
-
latent_embeds,
|
331
|
-
use_reentrant=False,
|
332
|
-
)
|
333
|
-
else:
|
334
|
-
# middle
|
335
|
-
sample = torch.utils.checkpoint.checkpoint(
|
336
|
-
create_custom_forward(self.mid_block), sample, latent_embeds
|
337
|
-
)
|
338
|
-
sample = sample.to(upscale_dtype)
|
339
|
-
|
340
|
-
# up
|
341
|
-
for up_block in self.up_blocks:
|
342
|
-
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
295
|
+
# up
|
296
|
+
for up_block in self.up_blocks:
|
297
|
+
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
|
343
298
|
else:
|
344
299
|
# middle
|
345
300
|
sample = self.mid_block(sample, latent_embeds)
|
@@ -558,72 +513,28 @@ class MaskConditionDecoder(nn.Module):
|
|
558
513
|
|
559
514
|
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
560
515
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
516
|
+
# middle
|
517
|
+
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
|
518
|
+
sample = sample.to(upscale_dtype)
|
561
519
|
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
# middle
|
570
|
-
sample = torch.utils.checkpoint.checkpoint(
|
571
|
-
create_custom_forward(self.mid_block),
|
572
|
-
sample,
|
573
|
-
latent_embeds,
|
574
|
-
use_reentrant=False,
|
575
|
-
)
|
576
|
-
sample = sample.to(upscale_dtype)
|
577
|
-
|
578
|
-
# condition encoder
|
579
|
-
if image is not None and mask is not None:
|
580
|
-
masked_image = (1 - mask) * image
|
581
|
-
im_x = torch.utils.checkpoint.checkpoint(
|
582
|
-
create_custom_forward(self.condition_encoder),
|
583
|
-
masked_image,
|
584
|
-
mask,
|
585
|
-
use_reentrant=False,
|
586
|
-
)
|
587
|
-
|
588
|
-
# up
|
589
|
-
for up_block in self.up_blocks:
|
590
|
-
if image is not None and mask is not None:
|
591
|
-
sample_ = im_x[str(tuple(sample.shape))]
|
592
|
-
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
593
|
-
sample = sample * mask_ + sample_ * (1 - mask_)
|
594
|
-
sample = torch.utils.checkpoint.checkpoint(
|
595
|
-
create_custom_forward(up_block),
|
596
|
-
sample,
|
597
|
-
latent_embeds,
|
598
|
-
use_reentrant=False,
|
599
|
-
)
|
600
|
-
if image is not None and mask is not None:
|
601
|
-
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
602
|
-
else:
|
603
|
-
# middle
|
604
|
-
sample = torch.utils.checkpoint.checkpoint(
|
605
|
-
create_custom_forward(self.mid_block), sample, latent_embeds
|
520
|
+
# condition encoder
|
521
|
+
if image is not None and mask is not None:
|
522
|
+
masked_image = (1 - mask) * image
|
523
|
+
im_x = self._gradient_checkpointing_func(
|
524
|
+
self.condition_encoder,
|
525
|
+
masked_image,
|
526
|
+
mask,
|
606
527
|
)
|
607
|
-
sample = sample.to(upscale_dtype)
|
608
528
|
|
609
|
-
|
610
|
-
|
611
|
-
masked_image = (1 - mask) * image
|
612
|
-
im_x = torch.utils.checkpoint.checkpoint(
|
613
|
-
create_custom_forward(self.condition_encoder),
|
614
|
-
masked_image,
|
615
|
-
mask,
|
616
|
-
)
|
617
|
-
|
618
|
-
# up
|
619
|
-
for up_block in self.up_blocks:
|
620
|
-
if image is not None and mask is not None:
|
621
|
-
sample_ = im_x[str(tuple(sample.shape))]
|
622
|
-
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
623
|
-
sample = sample * mask_ + sample_ * (1 - mask_)
|
624
|
-
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
529
|
+
# up
|
530
|
+
for up_block in self.up_blocks:
|
625
531
|
if image is not None and mask is not None:
|
626
|
-
|
532
|
+
sample_ = im_x[str(tuple(sample.shape))]
|
533
|
+
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
|
534
|
+
sample = sample * mask_ + sample_ * (1 - mask_)
|
535
|
+
sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
|
536
|
+
if image is not None and mask is not None:
|
537
|
+
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
|
627
538
|
else:
|
628
539
|
# middle
|
629
540
|
sample = self.mid_block(sample, latent_embeds)
|
@@ -890,17 +801,7 @@ class EncoderTiny(nn.Module):
|
|
890
801
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
891
802
|
r"""The forward method of the `EncoderTiny` class."""
|
892
803
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
893
|
-
|
894
|
-
def create_custom_forward(module):
|
895
|
-
def custom_forward(*inputs):
|
896
|
-
return module(*inputs)
|
897
|
-
|
898
|
-
return custom_forward
|
899
|
-
|
900
|
-
if is_torch_version(">=", "1.11.0"):
|
901
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
902
|
-
else:
|
903
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
804
|
+
x = self._gradient_checkpointing_func(self.layers, x)
|
904
805
|
|
905
806
|
else:
|
906
807
|
# scale image from [-1, 1] to [0, 1] to match TAESD convention
|
@@ -976,18 +877,7 @@ class DecoderTiny(nn.Module):
|
|
976
877
|
x = torch.tanh(x / 3) * 3
|
977
878
|
|
978
879
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
979
|
-
|
980
|
-
def create_custom_forward(module):
|
981
|
-
def custom_forward(*inputs):
|
982
|
-
return module(*inputs)
|
983
|
-
|
984
|
-
return custom_forward
|
985
|
-
|
986
|
-
if is_torch_version(">=", "1.11.0"):
|
987
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
|
988
|
-
else:
|
989
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
|
990
|
-
|
880
|
+
x = self._gradient_checkpointing_func(self.layers, x)
|
991
881
|
else:
|
992
882
|
x = self.layers(x)
|
993
883
|
|
@@ -71,6 +71,9 @@ class VQModel(ModelMixin, ConfigMixin):
|
|
71
71
|
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
|
72
72
|
"""
|
73
73
|
|
74
|
+
_skip_layerwise_casting_patterns = ["quantize"]
|
75
|
+
_supports_group_offloading = False
|
76
|
+
|
74
77
|
@register_to_config
|
75
78
|
def __init__(
|
76
79
|
self,
|
@@ -0,0 +1,108 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ..utils.logging import get_logger
|
16
|
+
|
17
|
+
|
18
|
+
# Module-level logger; `CacheMixin.disable_cache` uses it to warn when no caching technique is enabled.
logger = get_logger(__name__) # pylint: disable=invalid-name
|
19
|
+
|
20
|
+
|
21
|
+
class CacheMixin:
    r"""
    A class for enabling/disabling caching techniques on diffusion models.

    Supported caching techniques:
        - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
        - [FasterCache](https://huggingface.co/papers/2410.19355)

    Only one caching technique can be active at a time. The active configuration is kept in `_cache_config`, which
    is `None` while caching is disabled.
    """

    # Configuration of the active caching technique, or `None` when caching is disabled. Declared at class level so
    # `is_cache_enabled` works before `enable_cache` has ever assigned the per-instance value.
    _cache_config = None

    @property
    def is_cache_enabled(self) -> bool:
        r"""Whether a caching technique is currently enabled on this model."""
        return self._cache_config is not None

    def enable_cache(self, config) -> None:
        r"""
        Enable caching techniques on the model.

        Args:
            config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig]`):
                The configuration for applying the caching technique. Currently supported caching techniques are:
                    - [`~hooks.PyramidAttentionBroadcastConfig`]
                    - [`~hooks.FasterCacheConfig`]

        Raises:
            ValueError: If a caching technique is already enabled (call `disable_cache` first), or if `config` is
                not one of the supported configuration types.

        Example:

        ```python
        >>> import torch
        >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

        >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> config = PyramidAttentionBroadcastConfig(
        ...     spatial_attention_block_skip_range=2,
        ...     spatial_attention_timestep_skip_range=(100, 800),
        ...     current_timestep_callback=lambda: pipe.current_timestep,
        ... )
        >>> pipe.transformer.enable_cache(config)
        ```
        """
        # Imported lazily inside the method; presumably to avoid an import cycle between models and hooks (TODO confirm).
        from ..hooks import (
            FasterCacheConfig,
            PyramidAttentionBroadcastConfig,
            apply_faster_cache,
            apply_pyramid_attention_broadcast,
        )

        if self.is_cache_enabled:
            raise ValueError(
                f"Caching has already been enabled with {type(self._cache_config)}. "
                "To apply a new caching technique, please disable the existing one first."
            )

        if isinstance(config, PyramidAttentionBroadcastConfig):
            apply_pyramid_attention_broadcast(self, config)
        elif isinstance(config, FasterCacheConfig):
            apply_faster_cache(self, config)
        else:
            raise ValueError(f"Cache config {type(config)} is not supported.")

        # Only recorded after the hooks were applied, so a failed/unsupported config leaves caching disabled.
        self._cache_config = config

    def disable_cache(self) -> None:
        r"""
        Disable the currently enabled caching technique.

        This removes the hooks registered for the active technique, calling `remove_hook(..., recurse=True)` on the
        model's `HookRegistry`. If no caching technique is enabled, it logs a warning and returns without doing
        anything.

        Raises:
            ValueError: If the stored cache configuration is not one of the supported configuration types.
        """
        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
        from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
        from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK

        if self._cache_config is None:
            logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
            return

        if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
            registry = HookRegistry.check_if_exists_or_initialize(self)
            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
        elif isinstance(self._cache_config, FasterCacheConfig):
            registry = HookRegistry.check_if_exists_or_initialize(self)
            # FasterCache registers two hook kinds (denoiser-level and block-level); both must be removed.
            registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
            registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
        else:
            raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

        self._cache_config = None

    def _reset_stateful_cache(self, recurse: bool = True) -> None:
        r"""
        Reset the state held by stateful caching hooks.

        Obtains the model's registry via `HookRegistry.check_if_exists_or_initialize` and calls its
        `reset_stateful_hooks`, forwarding `recurse`.

        Args:
            recurse (`bool`, defaults to `True`):
                Passed through unchanged to `HookRegistry.reset_stateful_hooks`.
        """
        from ..hooks import HookRegistry

        HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
@@ -18,6 +18,7 @@ if is_torch_available():
|
|
18
18
|
from .controlnet_union import ControlNetUnionModel
|
19
19
|
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
20
20
|
from .multicontrolnet import MultiControlNetModel
|
21
|
+
from .multicontrolnet_union import MultiControlNetUnionModel
|
21
22
|
|
22
23
|
if is_flax_available():
|
23
24
|
from .controlnet_flax import FlaxControlNetModel
|
@@ -31,8 +31,6 @@ from ..attention_processor import (
|
|
31
31
|
from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
32
32
|
from ..modeling_utils import ModelMixin
|
33
33
|
from ..unets.unet_2d_blocks import (
|
34
|
-
CrossAttnDownBlock2D,
|
35
|
-
DownBlock2D,
|
36
34
|
UNetMidBlock2D,
|
37
35
|
UNetMidBlock2DCrossAttn,
|
38
36
|
get_down_block,
|
@@ -659,10 +657,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
659
657
|
for module in self.children():
|
660
658
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
661
659
|
|
662
|
-
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
663
|
-
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
664
|
-
module.gradient_checkpointing = value
|
665
|
-
|
666
660
|
def forward(
|
667
661
|
self,
|
668
662
|
sample: torch.Tensor,
|
@@ -740,10 +734,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
740
734
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
741
735
|
# This would be a good case for the `match` statement (Python 3.10+)
|
742
736
|
is_mps = sample.device.type == "mps"
|
737
|
+
is_npu = sample.device.type == "npu"
|
743
738
|
if isinstance(timestep, float):
|
744
|
-
dtype = torch.float32 if is_mps else torch.float64
|
739
|
+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
745
740
|
else:
|
746
|
-
dtype = torch.int32 if is_mps else torch.int64
|
741
|
+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
|
747
742
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
748
743
|
elif len(timesteps.shape) == 0:
|
749
744
|
timesteps = timesteps[None].to(sample.device)
|
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
|
22
22
|
from ...loaders import PeftAdapterMixin
|
23
23
|
from ...models.attention_processor import AttentionProcessor
|
24
24
|
from ...models.modeling_utils import ModelMixin
|
25
|
-
from ...utils import USE_PEFT_BACKEND, BaseOutput,
|
25
|
+
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
26
26
|
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
|
27
27
|
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
28
28
|
from ..modeling_outputs import Transformer2DModelOutput
|
@@ -178,10 +178,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
178
178
|
for name, module in self.named_children():
|
179
179
|
fn_recursive_attn_processor(name, module, processor)
|
180
180
|
|
181
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
182
|
-
if hasattr(module, "gradient_checkpointing"):
|
183
|
-
module.gradient_checkpointing = value
|
184
|
-
|
185
181
|
@classmethod
|
186
182
|
def from_transformer(
|
187
183
|
cls,
|
@@ -302,15 +298,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
302
298
|
)
|
303
299
|
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
304
300
|
|
305
|
-
if self.union:
|
306
|
-
# union mode
|
307
|
-
if controlnet_mode is None:
|
308
|
-
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
309
|
-
# union mode emb
|
310
|
-
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
311
|
-
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
312
|
-
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
313
|
-
|
314
301
|
if txt_ids.ndim == 3:
|
315
302
|
logger.warning(
|
316
303
|
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
@@ -324,30 +311,27 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
324
311
|
)
|
325
312
|
img_ids = img_ids[0]
|
326
313
|
|
314
|
+
if self.union:
|
315
|
+
# union mode
|
316
|
+
if controlnet_mode is None:
|
317
|
+
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
318
|
+
# union mode emb
|
319
|
+
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
320
|
+
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
321
|
+
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
|
322
|
+
|
327
323
|
ids = torch.cat((txt_ids, img_ids), dim=0)
|
328
324
|
image_rotary_emb = self.pos_embed(ids)
|
329
325
|
|
330
326
|
block_samples = ()
|
331
327
|
for index_block, block in enumerate(self.transformer_blocks):
|
332
328
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
333
|
-
|
334
|
-
|
335
|
-
def custom_forward(*inputs):
|
336
|
-
if return_dict is not None:
|
337
|
-
return module(*inputs, return_dict=return_dict)
|
338
|
-
else:
|
339
|
-
return module(*inputs)
|
340
|
-
|
341
|
-
return custom_forward
|
342
|
-
|
343
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
344
|
-
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
345
|
-
create_custom_forward(block),
|
329
|
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
330
|
+
block,
|
346
331
|
hidden_states,
|
347
332
|
encoder_hidden_states,
|
348
333
|
temb,
|
349
334
|
image_rotary_emb,
|
350
|
-
**ckpt_kwargs,
|
351
335
|
)
|
352
336
|
|
353
337
|
else:
|
@@ -364,23 +348,11 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
364
348
|
single_block_samples = ()
|
365
349
|
for index_block, block in enumerate(self.single_transformer_blocks):
|
366
350
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
367
|
-
|
368
|
-
|
369
|
-
def custom_forward(*inputs):
|
370
|
-
if return_dict is not None:
|
371
|
-
return module(*inputs, return_dict=return_dict)
|
372
|
-
else:
|
373
|
-
return module(*inputs)
|
374
|
-
|
375
|
-
return custom_forward
|
376
|
-
|
377
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
378
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
379
|
-
create_custom_forward(block),
|
351
|
+
hidden_states = self._gradient_checkpointing_func(
|
352
|
+
block,
|
380
353
|
hidden_states,
|
381
354
|
temb,
|
382
355
|
image_rotary_emb,
|
383
|
-
**ckpt_kwargs,
|
384
356
|
)
|
385
357
|
|
386
358
|
else:
|