diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import torch
|
|
18
18
|
import torch.nn.functional as F
|
19
19
|
from torch import nn
|
20
20
|
|
21
|
-
from ...utils import deprecate, is_torch_version, logging
|
21
|
+
from ...utils import deprecate, logging
|
22
22
|
from ...utils.torch_utils import apply_freeu
|
23
23
|
from ..activations import get_activation
|
24
24
|
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
@@ -737,25 +737,9 @@ class UNetMidBlock2D(nn.Module):
|
|
737
737
|
hidden_states = self.resnets[0](hidden_states, temb)
|
738
738
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
739
739
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
740
|
-
|
741
|
-
def create_custom_forward(module, return_dict=None):
|
742
|
-
def custom_forward(*inputs):
|
743
|
-
if return_dict is not None:
|
744
|
-
return module(*inputs, return_dict=return_dict)
|
745
|
-
else:
|
746
|
-
return module(*inputs)
|
747
|
-
|
748
|
-
return custom_forward
|
749
|
-
|
750
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
751
740
|
if attn is not None:
|
752
741
|
hidden_states = attn(hidden_states, temb=temb)
|
753
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
754
|
-
create_custom_forward(resnet),
|
755
|
-
hidden_states,
|
756
|
-
temb,
|
757
|
-
**ckpt_kwargs,
|
758
|
-
)
|
742
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
759
743
|
else:
|
760
744
|
if attn is not None:
|
761
745
|
hidden_states = attn(hidden_states, temb=temb)
|
@@ -883,17 +867,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
|
883
867
|
hidden_states = self.resnets[0](hidden_states, temb)
|
884
868
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
885
869
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
886
|
-
|
887
|
-
def create_custom_forward(module, return_dict=None):
|
888
|
-
def custom_forward(*inputs):
|
889
|
-
if return_dict is not None:
|
890
|
-
return module(*inputs, return_dict=return_dict)
|
891
|
-
else:
|
892
|
-
return module(*inputs)
|
893
|
-
|
894
|
-
return custom_forward
|
895
|
-
|
896
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
897
870
|
hidden_states = attn(
|
898
871
|
hidden_states,
|
899
872
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -902,12 +875,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
|
902
875
|
encoder_attention_mask=encoder_attention_mask,
|
903
876
|
return_dict=False,
|
904
877
|
)[0]
|
905
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
906
|
-
create_custom_forward(resnet),
|
907
|
-
hidden_states,
|
908
|
-
temb,
|
909
|
-
**ckpt_kwargs,
|
910
|
-
)
|
878
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
911
879
|
else:
|
912
880
|
hidden_states = attn(
|
913
881
|
hidden_states,
|
@@ -1156,23 +1124,7 @@ class AttnDownBlock2D(nn.Module):
|
|
1156
1124
|
|
1157
1125
|
for resnet, attn in zip(self.resnets, self.attentions):
|
1158
1126
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1159
|
-
|
1160
|
-
def create_custom_forward(module, return_dict=None):
|
1161
|
-
def custom_forward(*inputs):
|
1162
|
-
if return_dict is not None:
|
1163
|
-
return module(*inputs, return_dict=return_dict)
|
1164
|
-
else:
|
1165
|
-
return module(*inputs)
|
1166
|
-
|
1167
|
-
return custom_forward
|
1168
|
-
|
1169
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1170
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1171
|
-
create_custom_forward(resnet),
|
1172
|
-
hidden_states,
|
1173
|
-
temb,
|
1174
|
-
**ckpt_kwargs,
|
1175
|
-
)
|
1127
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
1176
1128
|
hidden_states = attn(hidden_states, **cross_attention_kwargs)
|
1177
1129
|
output_states = output_states + (hidden_states,)
|
1178
1130
|
else:
|
@@ -1304,23 +1256,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
|
1304
1256
|
|
1305
1257
|
for i, (resnet, attn) in enumerate(blocks):
|
1306
1258
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1307
|
-
|
1308
|
-
def create_custom_forward(module, return_dict=None):
|
1309
|
-
def custom_forward(*inputs):
|
1310
|
-
if return_dict is not None:
|
1311
|
-
return module(*inputs, return_dict=return_dict)
|
1312
|
-
else:
|
1313
|
-
return module(*inputs)
|
1314
|
-
|
1315
|
-
return custom_forward
|
1316
|
-
|
1317
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1318
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1319
|
-
create_custom_forward(resnet),
|
1320
|
-
hidden_states,
|
1321
|
-
temb,
|
1322
|
-
**ckpt_kwargs,
|
1323
|
-
)
|
1259
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
1324
1260
|
hidden_states = attn(
|
1325
1261
|
hidden_states,
|
1326
1262
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -1418,21 +1354,7 @@ class DownBlock2D(nn.Module):
|
|
1418
1354
|
|
1419
1355
|
for resnet in self.resnets:
|
1420
1356
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1421
|
-
|
1422
|
-
def create_custom_forward(module):
|
1423
|
-
def custom_forward(*inputs):
|
1424
|
-
return module(*inputs)
|
1425
|
-
|
1426
|
-
return custom_forward
|
1427
|
-
|
1428
|
-
if is_torch_version(">=", "1.11.0"):
|
1429
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1430
|
-
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
1431
|
-
)
|
1432
|
-
else:
|
1433
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1434
|
-
create_custom_forward(resnet), hidden_states, temb
|
1435
|
-
)
|
1357
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
1436
1358
|
else:
|
1437
1359
|
hidden_states = resnet(hidden_states, temb)
|
1438
1360
|
|
@@ -1906,21 +1828,7 @@ class ResnetDownsampleBlock2D(nn.Module):
|
|
1906
1828
|
|
1907
1829
|
for resnet in self.resnets:
|
1908
1830
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1909
|
-
|
1910
|
-
def create_custom_forward(module):
|
1911
|
-
def custom_forward(*inputs):
|
1912
|
-
return module(*inputs)
|
1913
|
-
|
1914
|
-
return custom_forward
|
1915
|
-
|
1916
|
-
if is_torch_version(">=", "1.11.0"):
|
1917
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1918
|
-
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
1919
|
-
)
|
1920
|
-
else:
|
1921
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1922
|
-
create_custom_forward(resnet), hidden_states, temb
|
1923
|
-
)
|
1831
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
1924
1832
|
else:
|
1925
1833
|
hidden_states = resnet(hidden_states, temb)
|
1926
1834
|
|
@@ -2058,17 +1966,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
|
2058
1966
|
|
2059
1967
|
for resnet, attn in zip(self.resnets, self.attentions):
|
2060
1968
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
2061
|
-
|
2062
|
-
def create_custom_forward(module, return_dict=None):
|
2063
|
-
def custom_forward(*inputs):
|
2064
|
-
if return_dict is not None:
|
2065
|
-
return module(*inputs, return_dict=return_dict)
|
2066
|
-
else:
|
2067
|
-
return module(*inputs)
|
2068
|
-
|
2069
|
-
return custom_forward
|
2070
|
-
|
2071
|
-
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
1969
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
2072
1970
|
hidden_states = attn(
|
2073
1971
|
hidden_states,
|
2074
1972
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -2153,21 +2051,7 @@ class KDownBlock2D(nn.Module):
|
|
2153
2051
|
|
2154
2052
|
for resnet in self.resnets:
|
2155
2053
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
2156
|
-
|
2157
|
-
def create_custom_forward(module):
|
2158
|
-
def custom_forward(*inputs):
|
2159
|
-
return module(*inputs)
|
2160
|
-
|
2161
|
-
return custom_forward
|
2162
|
-
|
2163
|
-
if is_torch_version(">=", "1.11.0"):
|
2164
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2165
|
-
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
2166
|
-
)
|
2167
|
-
else:
|
2168
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2169
|
-
create_custom_forward(resnet), hidden_states, temb
|
2170
|
-
)
|
2054
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
2171
2055
|
else:
|
2172
2056
|
hidden_states = resnet(hidden_states, temb)
|
2173
2057
|
|
@@ -2262,22 +2146,10 @@ class KCrossAttnDownBlock2D(nn.Module):
|
|
2262
2146
|
|
2263
2147
|
for resnet, attn in zip(self.resnets, self.attentions):
|
2264
2148
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
2265
|
-
|
2266
|
-
|
2267
|
-
def custom_forward(*inputs):
|
2268
|
-
if return_dict is not None:
|
2269
|
-
return module(*inputs, return_dict=return_dict)
|
2270
|
-
else:
|
2271
|
-
return module(*inputs)
|
2272
|
-
|
2273
|
-
return custom_forward
|
2274
|
-
|
2275
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
2276
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2277
|
-
create_custom_forward(resnet),
|
2149
|
+
hidden_states = self._gradient_checkpointing_func(
|
2150
|
+
resnet,
|
2278
2151
|
hidden_states,
|
2279
2152
|
temb,
|
2280
|
-
**ckpt_kwargs,
|
2281
2153
|
)
|
2282
2154
|
hidden_states = attn(
|
2283
2155
|
hidden_states,
|
@@ -2423,23 +2295,7 @@ class AttnUpBlock2D(nn.Module):
|
|
2423
2295
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
2424
2296
|
|
2425
2297
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
2426
|
-
|
2427
|
-
def create_custom_forward(module, return_dict=None):
|
2428
|
-
def custom_forward(*inputs):
|
2429
|
-
if return_dict is not None:
|
2430
|
-
return module(*inputs, return_dict=return_dict)
|
2431
|
-
else:
|
2432
|
-
return module(*inputs)
|
2433
|
-
|
2434
|
-
return custom_forward
|
2435
|
-
|
2436
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
2437
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2438
|
-
create_custom_forward(resnet),
|
2439
|
-
hidden_states,
|
2440
|
-
temb,
|
2441
|
-
**ckpt_kwargs,
|
2442
|
-
)
|
2298
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
2443
2299
|
hidden_states = attn(hidden_states)
|
2444
2300
|
else:
|
2445
2301
|
hidden_states = resnet(hidden_states, temb)
|
@@ -2588,23 +2444,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
|
2588
2444
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
2589
2445
|
|
2590
2446
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
2591
|
-
|
2592
|
-
def create_custom_forward(module, return_dict=None):
|
2593
|
-
def custom_forward(*inputs):
|
2594
|
-
if return_dict is not None:
|
2595
|
-
return module(*inputs, return_dict=return_dict)
|
2596
|
-
else:
|
2597
|
-
return module(*inputs)
|
2598
|
-
|
2599
|
-
return custom_forward
|
2600
|
-
|
2601
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
2602
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2603
|
-
create_custom_forward(resnet),
|
2604
|
-
hidden_states,
|
2605
|
-
temb,
|
2606
|
-
**ckpt_kwargs,
|
2607
|
-
)
|
2447
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
2608
2448
|
hidden_states = attn(
|
2609
2449
|
hidden_states,
|
2610
2450
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -2721,21 +2561,7 @@ class UpBlock2D(nn.Module):
|
|
2721
2561
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
2722
2562
|
|
2723
2563
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
2724
|
-
|
2725
|
-
def create_custom_forward(module):
|
2726
|
-
def custom_forward(*inputs):
|
2727
|
-
return module(*inputs)
|
2728
|
-
|
2729
|
-
return custom_forward
|
2730
|
-
|
2731
|
-
if is_torch_version(">=", "1.11.0"):
|
2732
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2733
|
-
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
2734
|
-
)
|
2735
|
-
else:
|
2736
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
2737
|
-
create_custom_forward(resnet), hidden_states, temb
|
2738
|
-
)
|
2564
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
2739
2565
|
else:
|
2740
2566
|
hidden_states = resnet(hidden_states, temb)
|
2741
2567
|
|
@@ -3251,21 +3077,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
|
3251
3077
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
3252
3078
|
|
3253
3079
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
3254
|
-
|
3255
|
-
def create_custom_forward(module):
|
3256
|
-
def custom_forward(*inputs):
|
3257
|
-
return module(*inputs)
|
3258
|
-
|
3259
|
-
return custom_forward
|
3260
|
-
|
3261
|
-
if is_torch_version(">=", "1.11.0"):
|
3262
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
3263
|
-
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
3264
|
-
)
|
3265
|
-
else:
|
3266
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
3267
|
-
create_custom_forward(resnet), hidden_states, temb
|
3268
|
-
)
|
3080
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
3269
3081
|
else:
|
3270
3082
|
hidden_states = resnet(hidden_states, temb)
|
3271
3083
|
|
@@ -3409,17 +3221,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
|
3409
3221
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
3410
3222
|
|
3411
3223
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
3412
|
-
|
3413
|
-
def create_custom_forward(module, return_dict=None):
|
3414
|
-
def custom_forward(*inputs):
|
3415
|
-
if return_dict is not None:
|
3416
|
-
return module(*inputs, return_dict=return_dict)
|
3417
|
-
else:
|
3418
|
-
return module(*inputs)
|
3419
|
-
|
3420
|
-
return custom_forward
|
3421
|
-
|
3422
|
-
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
3224
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
3423
3225
|
hidden_states = attn(
|
3424
3226
|
hidden_states,
|
3425
3227
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -3512,21 +3314,7 @@ class KUpBlock2D(nn.Module):
|
|
3512
3314
|
|
3513
3315
|
for resnet in self.resnets:
|
3514
3316
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
3515
|
-
|
3516
|
-
def create_custom_forward(module):
|
3517
|
-
def custom_forward(*inputs):
|
3518
|
-
return module(*inputs)
|
3519
|
-
|
3520
|
-
return custom_forward
|
3521
|
-
|
3522
|
-
if is_torch_version(">=", "1.11.0"):
|
3523
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
3524
|
-
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
3525
|
-
)
|
3526
|
-
else:
|
3527
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
3528
|
-
create_custom_forward(resnet), hidden_states, temb
|
3529
|
-
)
|
3317
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
|
3530
3318
|
else:
|
3531
3319
|
hidden_states = resnet(hidden_states, temb)
|
3532
3320
|
|
@@ -3640,22 +3428,10 @@ class KCrossAttnUpBlock2D(nn.Module):
|
|
3640
3428
|
|
3641
3429
|
for resnet, attn in zip(self.resnets, self.attentions):
|
3642
3430
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
3643
|
-
|
3644
|
-
|
3645
|
-
def custom_forward(*inputs):
|
3646
|
-
if return_dict is not None:
|
3647
|
-
return module(*inputs, return_dict=return_dict)
|
3648
|
-
else:
|
3649
|
-
return module(*inputs)
|
3650
|
-
|
3651
|
-
return custom_forward
|
3652
|
-
|
3653
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
3654
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
3655
|
-
create_custom_forward(resnet),
|
3431
|
+
hidden_states = self._gradient_checkpointing_func(
|
3432
|
+
resnet,
|
3656
3433
|
hidden_states,
|
3657
3434
|
temb,
|
3658
|
-
**ckpt_kwargs,
|
3659
3435
|
)
|
3660
3436
|
hidden_states = attn(
|
3661
3437
|
hidden_states,
|
@@ -166,6 +166,7 @@ class UNet2DConditionModel(
|
|
166
166
|
|
167
167
|
_supports_gradient_checkpointing = True
|
168
168
|
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
|
169
|
+
_skip_layerwise_casting_patterns = ["norm"]
|
169
170
|
|
170
171
|
@register_to_config
|
171
172
|
def __init__(
|
@@ -833,10 +834,6 @@ class UNet2DConditionModel(
|
|
833
834
|
for module in self.children():
|
834
835
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
835
836
|
|
836
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
837
|
-
if hasattr(module, "gradient_checkpointing"):
|
838
|
-
module.gradient_checkpointing = value
|
839
|
-
|
840
837
|
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
841
838
|
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
842
839
|
|
@@ -915,10 +912,11 @@ class UNet2DConditionModel(
|
|
915
912
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
916
913
|
# This would be a good case for the `match` statement (Python 3.10+)
|
917
914
|
is_mps = sample.device.type == "mps"
|
915
|
+
is_npu = sample.device.type == "npu"
|
918
916
|
if isinstance(timestep, float):
|
919
|
-
dtype = torch.float32 if is_mps else torch.float64
|
917
|
+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
920
918
|
else:
|
921
|
-
dtype = torch.int32 if is_mps else torch.int64
|
919
|
+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
|
922
920
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
923
921
|
elif len(timesteps.shape) == 0:
|
924
922
|
timesteps = timesteps[None].to(sample.device)
|
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|
17
17
|
import torch
|
18
18
|
from torch import nn
|
19
19
|
|
20
|
-
from ...utils import deprecate,
|
20
|
+
from ...utils import deprecate, logging
|
21
21
|
from ...utils.torch_utils import apply_freeu
|
22
22
|
from ..attention import Attention
|
23
23
|
from ..resnet import (
|
@@ -1078,31 +1078,14 @@ class UNetMidBlockSpatioTemporal(nn.Module):
|
|
1078
1078
|
)
|
1079
1079
|
|
1080
1080
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
1081
|
-
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1082
|
-
|
1083
|
-
def create_custom_forward(module, return_dict=None):
|
1084
|
-
def custom_forward(*inputs):
|
1085
|
-
if return_dict is not None:
|
1086
|
-
return module(*inputs, return_dict=return_dict)
|
1087
|
-
else:
|
1088
|
-
return module(*inputs)
|
1089
|
-
|
1090
|
-
return custom_forward
|
1091
|
-
|
1092
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1081
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1093
1082
|
hidden_states = attn(
|
1094
1083
|
hidden_states,
|
1095
1084
|
encoder_hidden_states=encoder_hidden_states,
|
1096
1085
|
image_only_indicator=image_only_indicator,
|
1097
1086
|
return_dict=False,
|
1098
1087
|
)[0]
|
1099
|
-
hidden_states =
|
1100
|
-
create_custom_forward(resnet),
|
1101
|
-
hidden_states,
|
1102
|
-
temb,
|
1103
|
-
image_only_indicator,
|
1104
|
-
**ckpt_kwargs,
|
1105
|
-
)
|
1088
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
1106
1089
|
else:
|
1107
1090
|
hidden_states = attn(
|
1108
1091
|
hidden_states,
|
@@ -1110,11 +1093,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
|
|
1110
1093
|
image_only_indicator=image_only_indicator,
|
1111
1094
|
return_dict=False,
|
1112
1095
|
)[0]
|
1113
|
-
hidden_states = resnet(
|
1114
|
-
hidden_states,
|
1115
|
-
temb,
|
1116
|
-
image_only_indicator=image_only_indicator,
|
1117
|
-
)
|
1096
|
+
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
1118
1097
|
|
1119
1098
|
return hidden_states
|
1120
1099
|
|
@@ -1169,34 +1148,9 @@ class DownBlockSpatioTemporal(nn.Module):
|
|
1169
1148
|
output_states = ()
|
1170
1149
|
for resnet in self.resnets:
|
1171
1150
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1172
|
-
|
1173
|
-
def create_custom_forward(module):
|
1174
|
-
def custom_forward(*inputs):
|
1175
|
-
return module(*inputs)
|
1176
|
-
|
1177
|
-
return custom_forward
|
1178
|
-
|
1179
|
-
if is_torch_version(">=", "1.11.0"):
|
1180
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1181
|
-
create_custom_forward(resnet),
|
1182
|
-
hidden_states,
|
1183
|
-
temb,
|
1184
|
-
image_only_indicator,
|
1185
|
-
use_reentrant=False,
|
1186
|
-
)
|
1187
|
-
else:
|
1188
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1189
|
-
create_custom_forward(resnet),
|
1190
|
-
hidden_states,
|
1191
|
-
temb,
|
1192
|
-
image_only_indicator,
|
1193
|
-
)
|
1151
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
1194
1152
|
else:
|
1195
|
-
hidden_states = resnet(
|
1196
|
-
hidden_states,
|
1197
|
-
temb,
|
1198
|
-
image_only_indicator=image_only_indicator,
|
1199
|
-
)
|
1153
|
+
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
1200
1154
|
|
1201
1155
|
output_states = output_states + (hidden_states,)
|
1202
1156
|
|
@@ -1281,25 +1235,8 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
|
|
1281
1235
|
|
1282
1236
|
blocks = list(zip(self.resnets, self.attentions))
|
1283
1237
|
for resnet, attn in blocks:
|
1284
|
-
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1285
|
-
|
1286
|
-
def create_custom_forward(module, return_dict=None):
|
1287
|
-
def custom_forward(*inputs):
|
1288
|
-
if return_dict is not None:
|
1289
|
-
return module(*inputs, return_dict=return_dict)
|
1290
|
-
else:
|
1291
|
-
return module(*inputs)
|
1292
|
-
|
1293
|
-
return custom_forward
|
1294
|
-
|
1295
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1296
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1297
|
-
create_custom_forward(resnet),
|
1298
|
-
hidden_states,
|
1299
|
-
temb,
|
1300
|
-
image_only_indicator,
|
1301
|
-
**ckpt_kwargs,
|
1302
|
-
)
|
1238
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1239
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
1303
1240
|
|
1304
1241
|
hidden_states = attn(
|
1305
1242
|
hidden_states,
|
@@ -1308,11 +1245,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
|
|
1308
1245
|
return_dict=False,
|
1309
1246
|
)[0]
|
1310
1247
|
else:
|
1311
|
-
hidden_states = resnet(
|
1312
|
-
hidden_states,
|
1313
|
-
temb,
|
1314
|
-
image_only_indicator=image_only_indicator,
|
1315
|
-
)
|
1248
|
+
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
1316
1249
|
hidden_states = attn(
|
1317
1250
|
hidden_states,
|
1318
1251
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -1385,34 +1318,9 @@ class UpBlockSpatioTemporal(nn.Module):
|
|
1385
1318
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1386
1319
|
|
1387
1320
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1388
|
-
|
1389
|
-
def create_custom_forward(module):
|
1390
|
-
def custom_forward(*inputs):
|
1391
|
-
return module(*inputs)
|
1392
|
-
|
1393
|
-
return custom_forward
|
1394
|
-
|
1395
|
-
if is_torch_version(">=", "1.11.0"):
|
1396
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1397
|
-
create_custom_forward(resnet),
|
1398
|
-
hidden_states,
|
1399
|
-
temb,
|
1400
|
-
image_only_indicator,
|
1401
|
-
use_reentrant=False,
|
1402
|
-
)
|
1403
|
-
else:
|
1404
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1405
|
-
create_custom_forward(resnet),
|
1406
|
-
hidden_states,
|
1407
|
-
temb,
|
1408
|
-
image_only_indicator,
|
1409
|
-
)
|
1321
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
1410
1322
|
else:
|
1411
|
-
hidden_states = resnet(
|
1412
|
-
hidden_states,
|
1413
|
-
temb,
|
1414
|
-
image_only_indicator=image_only_indicator,
|
1415
|
-
)
|
1323
|
+
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
1416
1324
|
|
1417
1325
|
if self.upsamplers is not None:
|
1418
1326
|
for upsampler in self.upsamplers:
|
@@ -1495,25 +1403,8 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
|
|
1495
1403
|
|
1496
1404
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
1497
1405
|
|
1498
|
-
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1499
|
-
|
1500
|
-
def create_custom_forward(module, return_dict=None):
|
1501
|
-
def custom_forward(*inputs):
|
1502
|
-
if return_dict is not None:
|
1503
|
-
return module(*inputs, return_dict=return_dict)
|
1504
|
-
else:
|
1505
|
-
return module(*inputs)
|
1506
|
-
|
1507
|
-
return custom_forward
|
1508
|
-
|
1509
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1510
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
1511
|
-
create_custom_forward(resnet),
|
1512
|
-
hidden_states,
|
1513
|
-
temb,
|
1514
|
-
image_only_indicator,
|
1515
|
-
**ckpt_kwargs,
|
1516
|
-
)
|
1406
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
1407
|
+
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator)
|
1517
1408
|
hidden_states = attn(
|
1518
1409
|
hidden_states,
|
1519
1410
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -1521,11 +1412,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
|
|
1521
1412
|
return_dict=False,
|
1522
1413
|
)[0]
|
1523
1414
|
else:
|
1524
|
-
hidden_states = resnet(
|
1525
|
-
hidden_states,
|
1526
|
-
temb,
|
1527
|
-
image_only_indicator=image_only_indicator,
|
1528
|
-
)
|
1415
|
+
hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator)
|
1529
1416
|
hidden_states = attn(
|
1530
1417
|
hidden_states,
|
1531
1418
|
encoder_hidden_states=encoder_hidden_states,
|