diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

@@ -19,7 +19,7 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter

@@ -31,7 +31,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin

@@ -40,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py

@@ -178,6 +186,7 @@ class AnimateDiffVideoToVideoPipeline(
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
     AnimateDiffFreeNoiseMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for video-to-video generation.

@@ -216,7 +225,7 @@ class AnimateDiffVideoToVideoPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         scheduler: Union[
             DDIMScheduler,

@@ -243,7 +252,7 @@ class AnimateDiffVideoToVideoPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def encode_prompt(

@@ -1037,6 +1046,9 @@ class AnimateDiffVideoToVideoPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 10. Post-processing
         if output_type == "latent":
             video = latents
diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

@@ -20,7 +20,7 @@ import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import (
     AutoencoderKL,
     ControlNetModel,

@@ -39,7 +39,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin

@@ -48,8 +48,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py

@@ -196,6 +204,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
     AnimateDiffFreeNoiseMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for video-to-video generation with ControlNet guidance.

@@ -238,7 +247,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[

@@ -270,7 +279,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
         self.control_video_processor = VideoProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False

@@ -1325,6 +1334,9 @@ class AnimateDiffVideoToVideoControlNetPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 11. Post-processing
         if output_type == "latent":
             video = latents
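Both AnimateDiff video-to-video pipelines now also inherit `FromSingleFileMixin`, so they can be constructed from a single consolidated checkpoint file rather than a diffusers-layout repository. A hedged sketch of what that enables; the checkpoint path is a placeholder, and depending on what the checkpoint contains you may still need to pass components such as `motion_adapter` explicitly:

```python
import torch
from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter

# Placeholder path: any SD1.5-style consolidated .safetensors checkpoint,
# paired with a compatible motion adapter.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffVideoToVideoPipeline.from_single_file(
    "path/to/checkpoint.safetensors",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
)
```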
diffusers/pipelines/audioldm/pipeline_audioldm.py

@@ -22,13 +22,21 @@ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py

@@ -94,7 +102,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
             scheduler=scheduler,
             vocoder=vocoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
 
     def _encode_prompt(
         self,

@@ -530,6 +538,9 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         mel_spectrogram = self.decode_latents(latents)
 
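The other change repeated verbatim across these files is the `vae_scale_factor` guard: the unconditional computation raised `AttributeError` when a pipeline was assembled without a VAE, while `getattr(self, "vae", None)` covers both a missing attribute and an explicit `vae=None`, falling back to the Stable-Diffusion-family default of 8. Its behavior in isolation (toy classes, assuming the usual four-entry `block_out_channels`):

```python
class _VaeConfig:
    block_out_channels = [128, 256, 512, 512]  # typical SD VAE config


class _Vae:
    config = _VaeConfig()


class _Pipe:
    def __init__(self, vae=None):
        self.vae = vae
        # 2 ** (4 - 1) == 8 for the four-block VAE; plain 8 when vae is None.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )


assert _Pipe(_Vae()).vae_scale_factor == 8
assert _Pipe(None).vae_scale_factor == 8  # no AttributeError anymore
```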
diffusers/pipelines/audioldm2/modeling_audioldm2.py

@@ -38,7 +38,7 @@ from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from ...models.transformers.transformer_2d import Transformer2DModel
 from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
 from ...models.unets.unet_2d_condition import UNet2DConditionOutput
-from ...utils import BaseOutput, is_torch_version, logging
+from ...utils import BaseOutput, logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -673,11 +673,6 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,

@@ -768,10 +763,11 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

@@ -1113,23 +1109,7 @@ class CrossAttnDownBlock2D(nn.Module):
 
         for i in range(num_layers):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states

@@ -1140,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep

@@ -1149,7 +1129,6 @@ class CrossAttnDownBlock2D(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
             else:
                 hidden_states = self.resnets[i](hidden_states, temb)

@@ -1291,17 +1270,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
 
         for i in range(len(self.resnets[1:])):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states

@@ -1312,8 +1280,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep

@@ -1321,14 +1289,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i + 1]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb)
             else:
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:

@@ -1465,23 +1427,7 @@ class CrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.resnets[i]),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
                 for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                     if cross_attention_dim is not None and idx <= 1:
                         forward_encoder_hidden_states = encoder_hidden_states

@@ -1492,8 +1438,8 @@ class CrossAttnUpBlock2D(nn.Module):
                     else:
                         forward_encoder_hidden_states = None
                         forward_encoder_attention_mask = None
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
+                    hidden_states = self._gradient_checkpointing_func(
+                        self.attentions[i * num_attention_per_layer + idx],
                         hidden_states,
                         forward_encoder_hidden_states,
                         None,  # timestep

@@ -1501,7 +1447,6 @@ class CrossAttnUpBlock2D(nn.Module):
                         cross_attention_kwargs,
                         attention_mask,
                         forward_encoder_attention_mask,
-                        **ckpt_kwargs,
                     )[0]
             else:
                 hidden_states = self.resnets[i](hidden_states, temb)
diffusers/pipelines/audioldm2/pipeline_audioldm2.py

@@ -20,7 +20,7 @@ import torch
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,

@@ -48,8 +48,20 @@ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
 if is_librosa_available():
     import librosa
 
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py

@@ -184,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,

@@ -207,7 +219,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             scheduler=scheduler,
             vocoder=vocoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
 
     # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
     def enable_vae_slicing(self):

@@ -225,7 +237,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         """
         self.vae.disable_slicing()
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`

@@ -237,11 +249,26 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
 
-        device = torch.device(f"cuda:{gpu_id}")
+        torch_device = torch.device(device)
+        device_index = torch_device.index
+
+        if gpu_id is not None and device_index is not None:
+            raise ValueError(
+                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
+                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
+            )
+
+        device_type = torch_device.type
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            device_mod = getattr(torch, device.type, None)
+            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         model_sequence = [
             self.text_encoder.text_model,

@@ -292,9 +319,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
 
             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
 
-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]
 
             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)

@@ -764,7 +791,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
 
         if transcription is None:
             if self.text_encoder_2.config.model_type == "vits":
-                raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+                raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
         elif transcription is not None and (
             not isinstance(transcription, str) and not isinstance(transcription, list)
         ):

@@ -1033,6 +1060,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.maybe_free_model_hooks()
 
         # 8. Post-processing
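Beyond the GPT-2 head swap (`output.hidden_states[-1]` from a `GPT2LMHeadModel` replaces `output.last_hidden_state` from a `GPT2Model`), the notable change here is `enable_model_cpu_offload` growing a `device` argument: the old hard-coded `cuda:{gpu_id}` target becomes a reconciliation between the legacy `gpu_id` and an arbitrary accelerator device, with `torch.cuda.empty_cache()` replaced by a backend-agnostic `getattr(torch, device.type)` lookup. The reconciliation logic, extracted as a standalone function for clarity (a sketch, not library API):

```python
from typing import Optional, Union

import torch


def resolve_offload_device(
    gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"
) -> torch.device:
    """Mirrors the gpu_id/device reconciliation in enable_model_cpu_offload."""
    torch_device = torch.device(device)
    if gpu_id is not None and torch_device.index is not None:
        raise ValueError(f"Cannot pass both `gpu_id`={gpu_id} and an indexed `device`={device}.")
    device_str = torch_device.type
    if gpu_id or torch_device.index:
        device_str = f"{device_str}:{gpu_id or torch_device.index}"
    return torch.device(device_str)


assert resolve_offload_device() == torch.device("cuda")
assert resolve_offload_device(gpu_id=1) == torch.device("cuda:1")
assert resolve_offload_device(device="cuda:2") == torch.device("cuda:2")
# Quirk inherited from the diff above: a falsy index 0 leaves the device
# unindexed, which is equivalent to the default accelerator anyway.
assert resolve_offload_device(device="cuda:0") == torch.device("cuda")
```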
diffusers/pipelines/aura_flow/pipeline_aura_flow.py

@@ -12,20 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 

@@ -124,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
 
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+    ]
 
     def __init__(
         self,

@@ -139,9 +151,7 @@ class AuraFlowPipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
-        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def check_inputs(

@@ -154,10 +164,19 @@ class AuraFlowPipeline(DiffusionPipeline):
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
+            )
 
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"

@@ -380,6 +399,14 @@ class AuraFlowPipeline(DiffusionPipeline):
             self.vae.decoder.conv_in.to(dtype)
             self.vae.decoder.mid_block.to(dtype)
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -401,6 +428,10 @@ class AuraFlowPipeline(DiffusionPipeline):
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.

@@ -455,6 +486,15 @@ class AuraFlowPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
 
         Examples:

@@ -476,8 +516,11 @@ class AuraFlowPipeline(DiffusionPipeline):
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+
         # 2. Determine batch size.
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1

@@ -534,6 +577,7 @@ class AuraFlowPipeline(DiffusionPipeline):
 
         # 6. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance

@@ -560,10 +604,22 @@ class AuraFlowPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else: