diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
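The headline additions visible in this listing are the new diffusers/hooks/ package (FasterCache, group offloading, layerwise casting, Pyramid Attention Broadcast), several new model and pipeline families (Wan, CogView4, ConsisID, EasyAnimate, Lumina2, OmniGen, SANA Sprint), and a Quanto quantization backend. As orientation, here is a hedged sketch of the group-offloading hook introduced by diffusers/hooks/group_offloading.py; the entry point and argument names follow the 0.33 documentation, and the checkpoint id is a placeholder:

import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "<repo-id>", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Keep groups of transformer blocks on the CPU and move each group to the GPU
# only while it runs, trading throughput for a much smaller peak VRAM footprint.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)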
diffusers/models/unets/unet_3d_condition.py CHANGED
@@ -37,11 +37,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
     UNetMidBlock3DCrossAttn,
-    UpBlock3D,
     get_down_block,
     get_up_block,
 )
@@ -97,6 +93,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     """

     _supports_gradient_checkpointing = False
+    _skip_layerwise_casting_patterns = ["norm", "time_embedding"]

     @register_to_config
     def __init__(
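The _skip_layerwise_casting_patterns attribute added above feeds the new layerwise-casting hook (diffusers/hooks/layerwise_casting.py): submodules whose names match one of the patterns (here normalization and time-embedding layers) are exempt from low-precision storage, which tends to degrade quality in exactly those layers. A minimal sketch of how a user opts in, assuming the enable_layerwise_casting entry point from the 0.33 docs and a placeholder checkpoint id:

import torch
from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained("<repo-id>", subfolder="unet")
# Weights are stored in fp8 and up-cast to the compute dtype around each forward;
# modules matching _skip_layerwise_casting_patterns keep their original dtype.
unet.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)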
@@ -471,10 +468,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         self.set_attn_processor(processor)

-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -624,10 +617,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
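This three-line pattern recurs in every UNet touched by this diff: 64-bit tensors are not supported on the MPS backend, and 0.33 extends the same fallback to Ascend NPU devices. Pulled out as a standalone helper for illustration (timestep_dtype is not a diffusers function, just a restatement of the logic above):

import torch

def timestep_dtype(device: torch.device, is_float: bool) -> torch.dtype:
    # float64/int64 are problematic on the mps and npu backends, so scalar
    # timesteps are promoted to 32-bit tensors there instead.
    narrow = device.type in ("mps", "npu")
    if is_float:
        return torch.float32 if narrow else torch.float64
    return torch.int32 if narrow else torch.int64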
@@ -644,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         t_emb = t_emb.to(dtype=self.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )

         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
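The repeat_interleave edits here and below all add the output_size keyword. Without it, Tensor.repeat_interleave must compute the output length from repeats at runtime, which can introduce a host-device sync and dynamic shapes under torch.compile; passing output_size makes the result shape explicit. A self-contained equivalence check:

import torch

emb = torch.randn(2, 1280)
num_frames = 16

# Old and new forms produce identical tensors; the output_size hint just spares
# PyTorch from inferring the output length at runtime.
old = emb.repeat_interleave(num_frames, dim=0)
new = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
assert torch.equal(old, new) and new.shape == (2 * num_frames, 1280)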
diffusers/models/unets/unet_i2vgen_xl.py CHANGED
@@ -35,11 +35,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..transformers.transformer_temporal import TransformerTemporalModel
 from .unet_3d_blocks import (
-    CrossAttnDownBlock3D,
-    CrossAttnUpBlock3D,
-    DownBlock3D,
     UNetMidBlock3DCrossAttn,
-    UpBlock3D,
     get_down_block,
     get_up_block,
 )
@@ -436,11 +432,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         self.set_attn_processor(processor)

-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -575,10 +566,11 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -600,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
@@ -628,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)

         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],
diffusers/models/unets/unet_kandinsky3.py CHANGED
@@ -205,10 +205,6 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
         """
         self.set_attn_processor(AttnProcessor())

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
         if encoder_attention_mask is not None:
             encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
diffusers/models/unets/unet_motion_model.py CHANGED
@@ -22,7 +22,7 @@ import torch.utils.checkpoint

 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, is_torch_version, logging
+from ...utils import BaseOutput, deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import BasicTransformerBlock
 from ..attention_processor import (
@@ -324,25 +324,7 @@ class DownBlockMotion(nn.Module):
         blocks = zip(self.resnets, self.motion_modules)
         for resnet, motion_module in blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
-
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)

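This hunk is the template for the rest of the diff: the per-call-site create_custom_forward closures and is_torch_version branches collapse into a single self._gradient_checkpointing_func, which also makes the per-model _set_gradient_checkpointing overrides (removed throughout) unnecessary. A hedged sketch of what such a shared helper amounts to; in 0.33 it is managed by ModelMixin, and this partial is an illustration rather than the actual implementation:

import functools
import torch

# Non-reentrant activation checkpointing applied uniformly, instead of a
# hand-rolled closure plus a torch-version check at every call site.
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

# Call-site shape used by the refactored blocks:
#     hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)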
@@ -514,23 +496,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)

@@ -543,10 +509,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                 return_dict=False,
             )[0]

-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)

             # apply additional residuals to the output of the last pair of resnet and attention blocks
             if i == len(blocks) - 1 and additional_residuals is not None:
@@ -733,23 +696,7 @@ class CrossAttnUpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)

@@ -762,10 +709,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                 return_dict=False,
             )[0]

-            hidden_states = motion_module(
-                hidden_states,
-                num_frames=num_frames,
-            )
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -896,24 +840,7 @@ class UpBlockMotion(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                if is_torch_version(">=", "1.11.0"):
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet),
-                        hidden_states,
-                        temb,
-                        use_reentrant=False,
-                    )
-                else:
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb
-                    )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)

@@ -1080,34 +1007,12 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
             )[0]

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
+                hidden_states = self._gradient_checkpointing_func(
+                    motion_module, hidden_states, None, None, None, num_frames, None
                 )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )
+                hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
                 hidden_states = resnet(input_tensor=hidden_states, temb=temb)

         return hidden_states
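Note the call-shape change in this mid-block hunk: torch.utils.checkpoint.checkpoint forwards arguments positionally, so the keyword call motion_module(hidden_states, num_frames=num_frames) becomes fully positional, with None standing in for the intermediate parameters. Assuming the motion module's forward is ordered like TransformerTemporalModel's (an inference from the positional call, not confirmed by this diff):

def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
            class_labels=None, num_frames=1, cross_attention_kwargs=None):
    ...

the Nones land on encoder_hidden_states, timestep, class_labels, and cross_attention_kwargs, while num_frames arrives in fifth position.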
@@ -1301,6 +1206,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
     """

     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["norm"]

     @register_to_config
     def __init__(
@@ -1965,10 +1871,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):

         self.set_attn_processor(processor)

-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -2114,10 +2016,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -2156,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             aug_emb = self.add_embedding(add_embeds)

         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -2165,7 +2068,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)

         # 2. pre-process
diffusers/models/unets/unet_spatio_temporal_condition.py CHANGED
@@ -320,10 +320,6 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         self.set_attn_processor(processor)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
         """
@@ -402,10 +398,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -434,9 +431,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
         # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )

         # 2. pre-process
         sample = self.conv_in(sample)
diffusers/models/unets/unet_stable_cascade.py CHANGED
@@ -387,9 +387,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):

         self.gradient_checkpointing = False

-    def _set_gradient_checkpointing(self, value=False):
-        self.gradient_checkpointing = value
-
     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
             torch.nn.init.xavier_uniform_(m.weight)
@@ -456,29 +453,18 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             for down_block, downscaler, repmap in block_group:
                 x = downscaler(x)
                 for i in range(len(repmap) + 1):
                     for block in down_block:
                         if isinstance(block, SDCascadeResBlock):
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block, x)
                         elif isinstance(block, SDCascadeAttnBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, clip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, clip)
                         elif isinstance(block, SDCascadeTimestepBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, r_embed, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, r_embed)
                         else:
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block)
                     if i < len(repmap):
                         x = repmap[i](x)
                 level_outputs.insert(0, x)
@@ -505,13 +491,6 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return module(*inputs)
-
-                return custom_forward
-
             for i, (up_block, upscaler, repmap) in enumerate(block_group):
                 for j in range(len(repmap) + 1):
                     for k, block in enumerate(up_block):
@@ -523,19 +502,13 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                                     x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                                 )
                                 x = x.to(orig_type)
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, skip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, skip)
                         elif isinstance(block, SDCascadeAttnBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, clip, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, clip)
                         elif isinstance(block, SDCascadeTimestepBlock):
-                            x = torch.utils.checkpoint.checkpoint(
-                                create_custom_forward(block), x, r_embed, use_reentrant=False
-                            )
+                            x = self._gradient_checkpointing_func(block, x, r_embed)
                         else:
-                            x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+                            x = self._gradient_checkpointing_func(block, x)
                         if j < len(repmap):
                             x = repmap[j](x)
                     x = upscaler(x)
diffusers/models/unets/uvit_2d.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -148,9 +148,6 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

         self.gradient_checkpointing = False

-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        pass
-
     def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
         encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
         encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
diffusers/optimization.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -258,7 +258,7 @@ def get_polynomial_decay_schedule_with_warmup(

     lr_init = optimizer.defaults["lr"]
     if not (lr_init > lr_end):
-        raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

     def lr_lambda(current_step: int):
         if current_step < num_warmup_steps: