diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
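The additions above are dominated by the new `diffusers/hooks/` package (group offloading, layerwise casting, pyramid attention broadcast, FasterCache), a batch of new pipelines (Wan, Lumina2, CogView4, ConsisID, OmniGen, EasyAnimate, SANA Sprint), `AutoModel`, and a quanto quantization backend. A minimal, hedged sketch of two of those additions, assuming the 0.33.0 APIs documented for `auto_model.py` and `hooks/layerwise_casting.py` (the checkpoint id and dtype choices are illustrative, not taken from this diff):

```python
import torch
from diffusers import AutoModel

# AutoModel (new in diffusers/models/auto_model.py) resolves the concrete
# model class from the checkpoint's config; the model id is illustrative.
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Layerwise casting (diffusers/hooks/layerwise_casting.py) stores weights in
# fp8 and upcasts each layer to the compute dtype on the fly.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```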
--- a/diffusers/pipelines/kolors/text_encoder.py
+++ b/diffusers/pipelines/kolors/text_encoder.py
@@ -104,13 +104,6 @@ class RMSNorm(torch.nn.Module):
         return (self.weight * hidden_states).to(input_dtype)
 
 
-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ class SelfAttention(torch.nn.Module):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ class SelfAttention(torch.nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
         )
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ class MLP(torch.nn.Module):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         def swiglu(x):
@@ -459,9 +449,7 @@ class MLP(torch.nn.Module):
         self.activation_func = swiglu
 
         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
 
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ class GLMBlock(torch.nn.Module):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         self.gradient_checkpointing = False
 
@@ -605,7 +587,7 @@ class GLMTransformer(torch.nn.Module):
 
             layer = self._get_layer(index)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                layer_ret = torch.utils.checkpoint.checkpoint(
+                layer_ret = self._gradient_checkpointing_func(
                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                 )
             else:
@@ -666,10 +648,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
         return position_ids
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, GLMTransformer):
-            module.gradient_checkpointing = value
-
 
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
@@ -683,9 +661,7 @@ class Embedding(torch.nn.Module):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection
 
     def forward(self, input_ids):
@@ -788,16 +764,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
             nn.Linear,
             config.hidden_size,
             config.padded_vocab_size,
             bias=False,
-            dtype=config.torch_dtype,
             **init_kwargs,
         )
         self.pre_seq_len = config.pre_seq_len
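A pattern note on the hunks above: every change in `kolors/text_encoder.py` removes a construction-time dtype (`**_config_to_kwargs(config)` or an explicit `dtype=config.torch_dtype`). Submodules are now built in the default dtype and the assembled model is cast once afterwards (for example by `from_pretrained(..., torch_dtype=...)`), which keeps dtype handling in one place. A minimal sketch of the resulting pattern; `MiniGLMBlock` is a made-up stand-in, not a class from this file:

```python
import torch
from torch import nn


class MiniGLMBlock(nn.Module):
    # Made-up stand-in for the GLMBlock pattern above.
    def __init__(self, hidden_size: int, device=None):
        super().__init__()
        # No per-layer dtype arguments anymore: weights start in the default dtype.
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=1e-5, device=device)
        self.dense = nn.Linear(hidden_size, hidden_size, device=device)


block = MiniGLMBlock(64)
# One cast of the assembled model replaces all the dtype=config.torch_dtype kwargs.
block.to(torch.float16)
assert block.dense.weight.dtype == torch.float16
```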
--- a/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -30,6 +30,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,6 +41,13 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -226,7 +234,7 @@ class LatentConsistencyModelImg2ImgPipeline(
                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
             )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
@@ -952,6 +960,9 @@ class LatentConsistencyModelImg2ImgPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
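The two additions in this file recur across most pipeline files in this release: a module-level guard for `torch_xla`, plus an `xm.mark_step()` at the end of each denoising iteration so XLA materializes one graph per step instead of lazily tracing the entire loop. Reduced to a standalone sketch, with a try/except standing in for diffusers' `is_torch_xla_available()` and a placeholder for the real denoising work:

```python
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def fake_denoise_step(latents: float, t: int) -> float:
    return latents * 0.99  # placeholder for noise prediction + scheduler.step


def run_denoising_loop(latents: float, timesteps: range) -> float:
    for t in timesteps:
        latents = fake_denoise_step(latents, t)
        if XLA_AVAILABLE:
            # Flush the lazily traced XLA graph once per step so TPU execution
            # proceeds step by step instead of compiling one enormous graph.
            xm.mark_step()
    return latents


print(run_denoising_loop(1.0, range(50)))
```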
--- a/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -29,6 +29,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -209,7 +218,7 @@ class LatentConsistencyModelPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -881,6 +890,9 @@ class LatentConsistencyModelPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
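The other recurring change, the `if getattr(self, "vae", None) else 8` guard, lets pipelines be constructed with `vae=None`. The fallback of 8 is exactly what the usual 4-entry `block_out_channels` VAE config produces, as this small re-derivation with stand-in classes shows:

```python
class FakeVaeConfig:
    # Typical Stable-Diffusion-style VAE config (illustrative values).
    block_out_channels = [128, 256, 512, 512]


class FakeVae:
    config = FakeVaeConfig()


def vae_scale_factor(vae) -> int:
    # Mirrors the guarded expression from the hunks: pipelines may now be
    # constructed with vae=None, so fall back to the common factor of 8.
    return 2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8


assert vae_scale_factor(FakeVae()) == 8  # 2 ** (4 - 1)
assert vae_scale_factor(None) == 8       # the new fallback path
```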
--- a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -25,10 +25,19 @@ from transformers.utils import logging
 
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 class LDMTextToImagePipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using latent diffusion.
@@ -202,6 +211,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # scale and decode the image latents with vae
         latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample
@@ -532,10 +544,6 @@ class LDMBertPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LDMBertEncoder,)):
-            module.gradient_checkpointing = value
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -676,15 +684,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer,
                     hidden_states,
                     attention_mask,
                     (head_mask[idx] if head_mask is not None else None),
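Both LDMBert hunks belong to 0.33.0's gradient-checkpointing refactor: the per-class `_set_gradient_checkpointing` override and the call-site `create_custom_forward` closures are replaced by a single `self._gradient_checkpointing_func` installed by the model base class. A sketch of the old and new call styles, assuming the installed function reduces to `torch.utils.checkpoint.checkpoint`:

```python
import functools

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)


# Old style (0.32.x): wrap the module in a throwaway closure at every call site.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


out_old = checkpoint(create_custom_forward(layer), x, use_reentrant=False)

# New style (0.33.0): one preconfigured function shared by all call sites.
_gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)
out_new = _gradient_checkpointing_func(layer, x)

assert torch.allclose(out_old, out_new)
```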
--- a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -15,11 +15,19 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 def preprocess(image):
     w, h = image.size
     w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
@@ -174,6 +182,9 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # decode the image latents with the VQVAE
         image = self.vqvae.decode(latents).sample
         image = torch.clamp(image, -1.0, 1.0)
--- a/diffusers/pipelines/latte/pipeline_latte.py
+++ b/diffusers/pipelines/latte/pipeline_latte.py
@@ -30,8 +30,10 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     BACKENDS_MAPPING,
     BaseOutput,
+    deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -39,8 +41,16 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
@@ -180,7 +190,7 @@ class LattePipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -592,6 +602,10 @@ class LattePipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -623,7 +637,7 @@ class LattePipeline(DiffusionPipeline):
         clean_caption: bool = True,
         mask_feature: bool = True,
         enable_temporal_attentions: bool = True,
-        decode_chunk_size:
+        decode_chunk_size: int = 14,
     ) -> Union[LattePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -719,6 +733,7 @@ class LattePipeline(DiffusionPipeline):
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Default height and width to transformer
@@ -780,6 +795,7 @@ class LattePipeline(DiffusionPipeline):
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
@@ -788,10 +804,11 @@ class LattePipeline(DiffusionPipeline):
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -800,7 +817,7 @@ class LattePipeline(DiffusionPipeline):
 
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=current_timestep,
                     enable_temporal_attentions=enable_temporal_attentions,
@@ -836,8 +853,20 @@ class LattePipeline(DiffusionPipeline):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
-        if not output_type == "latents":
-            video = self.decode_latents(latents, video_length, decode_chunk_size=14)
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        self._current_timestep = None
+
+        if output_type == "latents":
+            deprecation_message = (
+                "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
+            )
+            deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "latent"
+
+        if not output_type == "latent":
+            video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
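The last Latte hunk keeps accepting the old `output_type='latents'` spelling but warns and rewrites it to `'latent'`. The same normalization as a standalone function, with `warnings.warn` standing in for diffusers' internal `deprecate()` helper:

```python
import warnings


def normalize_output_type(output_type: str) -> str:
    if output_type == "latents":
        warnings.warn(
            "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead.",
            FutureWarning,
        )
        return "latent"
    return output_type


assert normalize_output_type("latents") == "latent"
assert normalize_output_type("pil") == "pil"
```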
--- a/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -19,6 +19,7 @@ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -29,26 +30,32 @@ from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> import PIL
-        >>> import requests
         >>> import torch
-        >>> from io import BytesIO
 
         >>> from diffusers import LEditsPPPipelineStableDiffusion
         >>> from diffusers.utils import load_image
 
         >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
        >>> pipe = pipe.to("cuda")
 
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
-        >>> image = load_image(img_url).
+        >>> image = load_image(img_url).resize((512, 512))
 
         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
 
@@ -152,7 +159,7 @@ class LeditsGaussianSmoothing:
 
         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -318,7 +325,7 @@ class LEditsPPPipelineStableDiffusion(
                 "The scheduler has been changed to DPMSolverMultistepScheduler."
             )
 
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+        if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -332,7 +339,7 @@ class LEditsPPPipelineStableDiffusion(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+        if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                 " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -361,10 +368,14 @@ class LEditsPPPipelineStableDiffusion(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        is_unet_version_less_0_9_0 = (
+            unet is not None
+            and hasattr(unet.config, "_diffusers_version")
+            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+        )
+        is_unet_sample_size_less_64 = (
+            unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        )
         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
             deprecation_message = (
                 "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -391,7 +402,7 @@ class LEditsPPPipelineStableDiffusion(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -706,6 +717,35 @@ class LEditsPPPipelineStableDiffusion(
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1182,6 +1222,9 @@ class LEditsPPPipelineStableDiffusion(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
@@ -1271,6 +1314,8 @@ class LEditsPPPipelineStableDiffusion(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
 
@@ -1349,6 +1394,9 @@ class LEditsPPPipelineStableDiffusion(
 
                 progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
         zs = zs.flip(0)
         self.zs = zs
@@ -1360,6 +1408,12 @@ class LEditsPPPipelineStableDiffusion(
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")
 
         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
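Finally, the two inversion hunks make `invert()` reject images whose preprocessed height or width is not a multiple of 32. A sketch of the check, plus an illustrative rounding helper that is not part of the diff:

```python
def validate_inversion_size(height: int, width: int) -> None:
    # Mirrors the added check: LEDITS++ inversion now requires multiples of 32.
    if height % 32 != 0 or width % 32 != 0:
        raise ValueError(
            "Image height and width must be a factor of 32. "
            "Consider down-sampling the input using the `height` and `width` parameters"
        )


def round_down_to_multiple_of_32(x: int) -> int:
    # Illustrative helper for choosing a valid target size.
    return x - x % 32


validate_inversion_size(512, 512)  # passes silently
assert round_down_to_multiple_of_32(517) == 512
```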
|