diffusers-0.32.1-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -54,11 +54,32 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
54
54
|
Args:
|
55
55
|
num_train_timesteps (`int`, defaults to 1000):
|
56
56
|
The number of diffusion steps to train the model.
|
57
|
-
timestep_spacing (`str`, defaults to `"linspace"`):
|
58
|
-
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
59
|
-
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
60
57
|
shift (`float`, defaults to 1.0):
|
61
58
|
The shift value for the timestep schedule.
|
59
|
+
use_dynamic_shifting (`bool`, defaults to False):
|
60
|
+
Whether to apply timestep shifting on-the-fly based on the image resolution.
|
61
|
+
base_shift (`float`, defaults to 0.5):
|
62
|
+
Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
|
63
|
+
with desired output.
|
64
|
+
max_shift (`float`, defaults to 1.15):
|
65
|
+
Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
|
66
|
+
more exaggerated or stylized.
|
67
|
+
base_image_seq_len (`int`, defaults to 256):
|
68
|
+
The base image sequence length.
|
69
|
+
max_image_seq_len (`int`, defaults to 4096):
|
70
|
+
The maximum image sequence length.
|
71
|
+
invert_sigmas (`bool`, defaults to False):
|
72
|
+
Whether to invert the sigmas.
|
73
|
+
shift_terminal (`float`, defaults to None):
|
74
|
+
The end value of the shifted timestep schedule.
|
75
|
+
use_karras_sigmas (`bool`, defaults to False):
|
76
|
+
Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
|
77
|
+
use_exponential_sigmas (`bool`, defaults to False):
|
78
|
+
Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
|
79
|
+
use_beta_sigmas (`bool`, defaults to False):
|
80
|
+
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
|
81
|
+
time_shift_type (`str`, defaults to "exponential"):
|
82
|
+
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
|
62
83
|
"""
|
63
84
|
|
64
85
|
_compatibles = []
|
@@ -69,7 +90,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
69
90
|
self,
|
70
91
|
num_train_timesteps: int = 1000,
|
71
92
|
shift: float = 1.0,
|
72
|
-
use_dynamic_shifting=False,
|
93
|
+
use_dynamic_shifting: bool = False,
|
73
94
|
base_shift: Optional[float] = 0.5,
|
74
95
|
max_shift: Optional[float] = 1.15,
|
75
96
|
base_image_seq_len: Optional[int] = 256,
|
@@ -79,6 +100,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
79
100
|
use_karras_sigmas: Optional[bool] = False,
|
80
101
|
use_exponential_sigmas: Optional[bool] = False,
|
81
102
|
use_beta_sigmas: Optional[bool] = False,
|
103
|
+
time_shift_type: str = "exponential",
|
82
104
|
):
|
83
105
|
if self.config.use_beta_sigmas and not is_scipy_available():
|
84
106
|
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
@@ -86,6 +108,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
86
108
|
raise ValueError(
|
87
109
|
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
88
110
|
)
|
111
|
+
if time_shift_type not in {"exponential", "linear"}:
|
112
|
+
raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
|
113
|
+
|
89
114
|
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
90
115
|
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
91
116
|
|
@@ -192,7 +217,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
192
217
|
return sigma * self.config.num_train_timesteps
|
193
218
|
|
194
219
|
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
195
|
-
|
220
|
+
if self.config.time_shift_type == "exponential":
|
221
|
+
return self._time_shift_exponential(mu, sigma, t)
|
222
|
+
elif self.config.time_shift_type == "linear":
|
223
|
+
return self._time_shift_linear(mu, sigma, t)
|
196
224
|
|
197
225
|
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
198
226
|
r"""
|
@@ -217,54 +245,94 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
217
245
|
|
218
246
|
def set_timesteps(
|
219
247
|
self,
|
220
|
-
num_inference_steps: int = None,
|
248
|
+
num_inference_steps: Optional[int] = None,
|
221
249
|
device: Union[str, torch.device] = None,
|
222
250
|
sigmas: Optional[List[float]] = None,
|
223
251
|
mu: Optional[float] = None,
|
252
|
+
timesteps: Optional[List[float]] = None,
|
224
253
|
):
|
225
254
|
"""
|
226
255
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
227
256
|
|
228
257
|
Args:
|
229
|
-
num_inference_steps (`int
|
258
|
+
num_inference_steps (`int`, *optional*):
|
230
259
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
231
260
|
device (`str` or `torch.device`, *optional*):
|
232
261
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
262
|
+
sigmas (`List[float]`, *optional*):
|
263
|
+
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
|
264
|
+
automatically.
|
265
|
+
mu (`float`, *optional*):
|
266
|
+
Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
|
267
|
+
shifting.
|
268
|
+
timesteps (`List[float]`, *optional*):
|
269
|
+
Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
|
270
|
+
automatically.
|
233
271
|
"""
|
234
272
|
if self.config.use_dynamic_shifting and mu is None:
|
235
|
-
raise ValueError("
|
273
|
+
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
|
274
|
+
|
275
|
+
if sigmas is not None and timesteps is not None:
|
276
|
+
if len(sigmas) != len(timesteps):
|
277
|
+
raise ValueError("`sigmas` and `timesteps` should have the same length")
|
278
|
+
|
279
|
+
if num_inference_steps is not None:
|
280
|
+
if (sigmas is not None and len(sigmas) != num_inference_steps) or (
|
281
|
+
timesteps is not None and len(timesteps) != num_inference_steps
|
282
|
+
):
|
283
|
+
raise ValueError(
|
284
|
+
"`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
|
285
|
+
)
|
286
|
+
else:
|
287
|
+
num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
|
236
288
|
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
289
|
+
self.num_inference_steps = num_inference_steps
|
290
|
+
|
291
|
+
# 1. Prepare default sigmas
|
292
|
+
is_timesteps_provided = timesteps is not None
|
241
293
|
|
294
|
+
if is_timesteps_provided:
|
295
|
+
timesteps = np.array(timesteps).astype(np.float32)
|
296
|
+
|
297
|
+
if sigmas is None:
|
298
|
+
if timesteps is None:
|
299
|
+
timesteps = np.linspace(
|
300
|
+
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
301
|
+
)
|
242
302
|
sigmas = timesteps / self.config.num_train_timesteps
|
243
303
|
else:
|
244
304
|
sigmas = np.array(sigmas).astype(np.float32)
|
245
305
|
num_inference_steps = len(sigmas)
|
246
|
-
self.num_inference_steps = num_inference_steps
|
247
306
|
|
307
|
+
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
|
308
|
+
# "exponential" or "linear" type is applied
|
248
309
|
if self.config.use_dynamic_shifting:
|
249
310
|
sigmas = self.time_shift(mu, 1.0, sigmas)
|
250
311
|
else:
|
251
312
|
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
252
313
|
|
314
|
+
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
|
253
315
|
if self.config.shift_terminal:
|
254
316
|
sigmas = self.stretch_shift_to_terminal(sigmas)
|
255
317
|
|
318
|
+
# 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
|
256
319
|
if self.config.use_karras_sigmas:
|
257
320
|
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
258
|
-
|
259
321
|
elif self.config.use_exponential_sigmas:
|
260
322
|
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
261
|
-
|
262
323
|
elif self.config.use_beta_sigmas:
|
263
324
|
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
264
325
|
|
326
|
+
# 5. Convert sigmas and timesteps to tensors and move to specified device
|
265
327
|
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
266
|
-
|
328
|
+
if not is_timesteps_provided:
|
329
|
+
timesteps = sigmas * self.config.num_train_timesteps
|
330
|
+
else:
|
331
|
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
267
332
|
|
333
|
+
# 6. Append the terminal sigma value.
|
334
|
+
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
|
335
|
+
# `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
|
268
336
|
if self.config.invert_sigmas:
|
269
337
|
sigmas = 1.0 - sigmas
|
270
338
|
timesteps = sigmas * self.config.num_train_timesteps
|
@@ -272,7 +340,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
272
340
|
else:
|
273
341
|
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
274
342
|
|
275
|
-
self.timesteps = timesteps
|
343
|
+
self.timesteps = timesteps
|
276
344
|
self.sigmas = sigmas
|
277
345
|
self._step_index = None
|
278
346
|
self._begin_index = None
|
@@ -309,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
309
377
|
s_tmax: float = float("inf"),
|
310
378
|
s_noise: float = 1.0,
|
311
379
|
generator: Optional[torch.Generator] = None,
|
380
|
+
per_token_timesteps: Optional[torch.Tensor] = None,
|
312
381
|
return_dict: bool = True,
|
313
382
|
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
314
383
|
"""
|
@@ -329,14 +398,17 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
329
398
|
Scaling factor for noise added to the sample.
|
330
399
|
generator (`torch.Generator`, *optional*):
|
331
400
|
A random number generator.
|
401
|
+
per_token_timesteps (`torch.Tensor`, *optional*):
|
402
|
+
The timesteps for each token in the sample.
|
332
403
|
return_dict (`bool`):
|
333
|
-
Whether or not to return a
|
334
|
-
tuple.
|
404
|
+
Whether or not to return a
|
405
|
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
|
335
406
|
|
336
407
|
Returns:
|
337
|
-
[`~schedulers.
|
338
|
-
If return_dict is `True`,
|
339
|
-
|
408
|
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
|
409
|
+
If return_dict is `True`,
|
410
|
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
|
411
|
+
otherwise a tuple is returned where the first element is the sample tensor.
|
340
412
|
"""
|
341
413
|
|
342
414
|
if (
|
@@ -347,7 +419,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
347
419
|
raise ValueError(
|
348
420
|
(
|
349
421
|
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
350
|
-
" `
|
422
|
+
" `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
351
423
|
" one of the `scheduler.timesteps` as a timestep."
|
352
424
|
),
|
353
425
|
)
|
@@ -358,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
358
430
|
# Upcast to avoid precision issues when computing prev_sample
|
359
431
|
sample = sample.to(torch.float32)
|
360
432
|
|
361
|
-
|
362
|
-
|
433
|
+
if per_token_timesteps is not None:
|
434
|
+
per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
|
363
435
|
|
364
|
-
|
436
|
+
sigmas = self.sigmas[:, None, None]
|
437
|
+
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
|
438
|
+
lower_sigmas = lower_mask * sigmas
|
439
|
+
lower_sigmas, _ = lower_sigmas.max(dim=0)
|
440
|
+
dt = (per_token_sigmas - lower_sigmas)[..., None]
|
441
|
+
else:
|
442
|
+
sigma = self.sigmas[self.step_index]
|
443
|
+
sigma_next = self.sigmas[self.step_index + 1]
|
444
|
+
dt = sigma_next - sigma
|
365
445
|
|
366
|
-
|
367
|
-
prev_sample = prev_sample.to(model_output.dtype)
|
446
|
+
prev_sample = sample + dt * model_output
|
368
447
|
|
369
448
|
# upon completion increase step index by one
|
370
449
|
self._step_index += 1
|
450
|
+
if per_token_timesteps is None:
|
451
|
+
# Cast sample back to model compatible dtype
|
452
|
+
prev_sample = prev_sample.to(model_output.dtype)
|
371
453
|
|
372
454
|
if not return_dict:
|
373
455
|
return (prev_sample,)
|
@@ -454,5 +536,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
454
536
|
)
|
455
537
|
return sigmas
|
456
538
|
|
539
|
+
def _time_shift_exponential(self, mu, sigma, t):
|
540
|
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
541
|
+
|
542
|
+
def _time_shift_linear(self, mu, sigma, t):
|
543
|
+
return mu / (mu + (1 / t - 1) ** sigma)
|
544
|
+
|
457
545
|
def __len__(self):
|
458
546
|
return self.config.num_train_timesteps
|
@@ -228,13 +228,14 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
228
228
|
generator (`torch.Generator`, *optional*):
|
229
229
|
A random number generator.
|
230
230
|
return_dict (`bool`):
|
231
|
-
Whether or not to return a
|
232
|
-
tuple.
|
231
|
+
Whether or not to return a
|
232
|
+
[`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] tuple.
|
233
233
|
|
234
234
|
Returns:
|
235
|
-
[`~schedulers.
|
236
|
-
If return_dict is `True`,
|
237
|
-
|
235
|
+
[`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
|
236
|
+
If return_dict is `True`,
|
237
|
+
[`~schedulers.scheduling_flow_match_heun_discrete.FlowMatchHeunDiscreteSchedulerOutput`] is returned,
|
238
|
+
otherwise a tuple is returned where the first element is the sample tensor.
|
238
239
|
"""
|
239
240
|
|
240
241
|
if (
|
@@ -245,7 +246,7 @@ class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
245
246
|
raise ValueError(
|
246
247
|
(
|
247
248
|
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
248
|
-
" `
|
249
|
+
" `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
|
249
250
|
" one of the `scheduler.timesteps` as a timestep."
|
250
251
|
),
|
251
252
|
)
|
@@ -342,7 +342,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
342
342
|
timesteps = torch.from_numpy(timesteps)
|
343
343
|
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
344
344
|
|
345
|
-
self.timesteps = timesteps.to(device=device)
|
345
|
+
self.timesteps = timesteps.to(device=device, dtype=torch.float32)
|
346
346
|
|
347
347
|
# empty dt and derivative
|
348
348
|
self.prev_derivative = None
|
@@ -413,8 +413,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
413
413
|
|
414
414
|
if timesteps[0] >= self.config.num_train_timesteps:
|
415
415
|
raise ValueError(
|
416
|
-
f"`timesteps` must start before `self.config.train_timesteps`:"
|
417
|
-
f" {self.config.num_train_timesteps}."
|
416
|
+
f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
|
418
417
|
)
|
419
418
|
|
420
419
|
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
|
@@ -311,7 +311,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
311
311
|
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
312
312
|
|
313
313
|
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
314
|
-
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
314
|
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
|
315
315
|
self._step_index = None
|
316
316
|
self._begin_index = None
|
317
317
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
@@ -319,7 +319,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
|
319
319
|
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
|
320
320
|
|
321
321
|
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
|
322
|
-
|
322
|
+
# The computation reported in Algorithm 1 Line 5 is incorrect. Line 5 refers to formula (8a) of the same paper,
|
323
|
+
# which tells to sample from a Gaussian distribution with mean "(alpha_prod_t_prev**0.5) * original_image"
|
324
|
+
# and variance "(1 - alpha_prod_t_prev)". This means that the standard Gaussian distribution "noise" should be
|
325
|
+
# scaled by the square root of the variance (as it is done here), however Algorithm 1 Line 5 tells to scale by the variance.
|
326
|
+
prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
|
323
327
|
|
324
328
|
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
|
325
329
|
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
|
@@ -0,0 +1,265 @@
|
|
1
|
+
# # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16
|
+
# and https://github.com/hojonathanho/diffusion
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from typing import Optional, Tuple, Union
|
20
|
+
|
21
|
+
import numpy as np
|
22
|
+
import torch
|
23
|
+
|
24
|
+
from ..configuration_utils import ConfigMixin, register_to_config
|
25
|
+
from ..schedulers.scheduling_utils import SchedulerMixin
|
26
|
+
from ..utils import BaseOutput, logging
|
27
|
+
from ..utils.torch_utils import randn_tensor
|
28
|
+
|
29
|
+
|
30
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
|
34
|
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->SCM
|
35
|
+
class SCMSchedulerOutput(BaseOutput):
|
36
|
+
"""
|
37
|
+
Output class for the scheduler's `step` function output.
|
38
|
+
|
39
|
+
Args:
|
40
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41
|
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
42
|
+
denoising loop.
|
43
|
+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
44
|
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
45
|
+
`pred_original_sample` can be used to preview progress or for guidance.
|
46
|
+
"""
|
47
|
+
|
48
|
+
prev_sample: torch.Tensor
|
49
|
+
pred_original_sample: Optional[torch.Tensor] = None
|
50
|
+
|
51
|
+
|
52
|
+
class SCMScheduler(SchedulerMixin, ConfigMixin):
|
53
|
+
"""
|
54
|
+
`SCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
55
|
+
non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass
|
56
|
+
documentation for the generic methods the library implements for all schedulers such as loading and saving.
|
57
|
+
|
58
|
+
Args:
|
59
|
+
num_train_timesteps (`int`, defaults to 1000):
|
60
|
+
The number of diffusion steps to train the model.
|
61
|
+
prediction_type (`str`, defaults to `trigflow`):
|
62
|
+
Prediction type of the scheduler function. Currently only supports "trigflow".
|
63
|
+
sigma_data (`float`, defaults to 0.5):
|
64
|
+
The standard deviation of the noise added during multi-step inference.
|
65
|
+
"""
|
66
|
+
|
67
|
+
# _compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
68
|
+
order = 1
|
69
|
+
|
70
|
+
@register_to_config
|
71
|
+
def __init__(
|
72
|
+
self,
|
73
|
+
num_train_timesteps: int = 1000,
|
74
|
+
prediction_type: str = "trigflow",
|
75
|
+
sigma_data: float = 0.5,
|
76
|
+
):
|
77
|
+
"""
|
78
|
+
Initialize the SCM scheduler.
|
79
|
+
|
80
|
+
Args:
|
81
|
+
num_train_timesteps (`int`, defaults to 1000):
|
82
|
+
The number of diffusion steps to train the model.
|
83
|
+
prediction_type (`str`, defaults to `trigflow`):
|
84
|
+
Prediction type of the scheduler function. Currently only supports "trigflow".
|
85
|
+
sigma_data (`float`, defaults to 0.5):
|
86
|
+
The standard deviation of the noise added during multi-step inference.
|
87
|
+
"""
|
88
|
+
# standard deviation of the initial noise distribution
|
89
|
+
self.init_noise_sigma = 1.0
|
90
|
+
|
91
|
+
# setable values
|
92
|
+
self.num_inference_steps = None
|
93
|
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
94
|
+
|
95
|
+
self._step_index = None
|
96
|
+
self._begin_index = None
|
97
|
+
|
98
|
+
@property
|
99
|
+
def step_index(self):
|
100
|
+
return self._step_index
|
101
|
+
|
102
|
+
@property
|
103
|
+
def begin_index(self):
|
104
|
+
return self._begin_index
|
105
|
+
|
106
|
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
107
|
+
def set_begin_index(self, begin_index: int = 0):
|
108
|
+
"""
|
109
|
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
110
|
+
|
111
|
+
Args:
|
112
|
+
begin_index (`int`):
|
113
|
+
The begin index for the scheduler.
|
114
|
+
"""
|
115
|
+
self._begin_index = begin_index
|
116
|
+
|
117
|
+
def set_timesteps(
|
118
|
+
self,
|
119
|
+
num_inference_steps: int,
|
120
|
+
timesteps: torch.Tensor = None,
|
121
|
+
device: Union[str, torch.device] = None,
|
122
|
+
max_timesteps: float = 1.57080,
|
123
|
+
intermediate_timesteps: float = 1.3,
|
124
|
+
):
|
125
|
+
"""
|
126
|
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
127
|
+
|
128
|
+
Args:
|
129
|
+
num_inference_steps (`int`):
|
130
|
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
131
|
+
timesteps (`torch.Tensor`, *optional*):
|
132
|
+
Custom timesteps to use for the denoising process.
|
133
|
+
max_timesteps (`float`, defaults to 1.57080):
|
134
|
+
The maximum timestep value used in the SCM scheduler.
|
135
|
+
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
136
|
+
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
|
137
|
+
"""
|
138
|
+
if num_inference_steps > self.config.num_train_timesteps:
|
139
|
+
raise ValueError(
|
140
|
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
141
|
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
142
|
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
143
|
+
)
|
144
|
+
|
145
|
+
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
|
146
|
+
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
|
147
|
+
|
148
|
+
if timesteps is not None and max_timesteps is not None:
|
149
|
+
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
|
150
|
+
|
151
|
+
if timesteps is None and max_timesteps is None:
|
152
|
+
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
|
153
|
+
|
154
|
+
if intermediate_timesteps is not None and num_inference_steps != 2:
|
155
|
+
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
|
156
|
+
|
157
|
+
self.num_inference_steps = num_inference_steps
|
158
|
+
|
159
|
+
if timesteps is not None:
|
160
|
+
if isinstance(timesteps, list):
|
161
|
+
self.timesteps = torch.tensor(timesteps, device=device).float()
|
162
|
+
elif isinstance(timesteps, torch.Tensor):
|
163
|
+
self.timesteps = timesteps.to(device).float()
|
164
|
+
else:
|
165
|
+
raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
|
166
|
+
elif intermediate_timesteps is not None:
|
167
|
+
self.timesteps = torch.tensor([max_timesteps, intermediate_timesteps, 0], device=device).float()
|
168
|
+
else:
|
169
|
+
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
|
170
|
+
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
|
171
|
+
print(f"Set timesteps: {self.timesteps}")
|
172
|
+
|
173
|
+
self._step_index = None
|
174
|
+
self._begin_index = None
|
175
|
+
|
176
|
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
177
|
+
def _init_step_index(self, timestep):
|
178
|
+
if self.begin_index is None:
|
179
|
+
if isinstance(timestep, torch.Tensor):
|
180
|
+
timestep = timestep.to(self.timesteps.device)
|
181
|
+
self._step_index = self.index_for_timestep(timestep)
|
182
|
+
else:
|
183
|
+
self._step_index = self._begin_index
|
184
|
+
|
185
|
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
186
|
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
187
|
+
if schedule_timesteps is None:
|
188
|
+
schedule_timesteps = self.timesteps
|
189
|
+
|
190
|
+
indices = (schedule_timesteps == timestep).nonzero()
|
191
|
+
|
192
|
+
# The sigma index that is taken for the **very** first `step`
|
193
|
+
# is always the second index (or the last index if there is only 1)
|
194
|
+
# This way we can ensure we don't accidentally skip a sigma in
|
195
|
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
196
|
+
pos = 1 if len(indices) > 1 else 0
|
197
|
+
|
198
|
+
return indices[pos].item()
|
199
|
+
|
200
|
+
def step(
|
201
|
+
self,
|
202
|
+
model_output: torch.FloatTensor,
|
203
|
+
timestep: float,
|
204
|
+
sample: torch.FloatTensor,
|
205
|
+
generator: torch.Generator = None,
|
206
|
+
return_dict: bool = True,
|
207
|
+
) -> Union[SCMSchedulerOutput, Tuple]:
|
208
|
+
"""
|
209
|
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
210
|
+
process from the learned model outputs (most often the predicted noise).
|
211
|
+
|
212
|
+
Args:
|
213
|
+
model_output (`torch.FloatTensor`):
|
214
|
+
The direct output from learned diffusion model.
|
215
|
+
timestep (`float`):
|
216
|
+
The current discrete timestep in the diffusion chain.
|
217
|
+
sample (`torch.FloatTensor`):
|
218
|
+
A current instance of a sample created by the diffusion process.
|
219
|
+
return_dict (`bool`, *optional*, defaults to `True`):
|
220
|
+
Whether or not to return a [`~schedulers.scheduling_scm.SCMSchedulerOutput`] or `tuple`.
|
221
|
+
Returns:
|
222
|
+
[`~schedulers.scheduling_utils.SCMSchedulerOutput`] or `tuple`:
|
223
|
+
If return_dict is `True`, [`~schedulers.scheduling_scm.SCMSchedulerOutput`] is returned, otherwise a
|
224
|
+
tuple is returned where the first element is the sample tensor.
|
225
|
+
"""
|
226
|
+
if self.num_inference_steps is None:
|
227
|
+
raise ValueError(
|
228
|
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
229
|
+
)
|
230
|
+
|
231
|
+
if self.step_index is None:
|
232
|
+
self._init_step_index(timestep)
|
233
|
+
|
234
|
+
# 2. compute alphas, betas
|
235
|
+
t = self.timesteps[self.step_index + 1]
|
236
|
+
s = self.timesteps[self.step_index]
|
237
|
+
|
238
|
+
# 4. Different Parameterization:
|
239
|
+
parameterization = self.config.prediction_type
|
240
|
+
|
241
|
+
if parameterization == "trigflow":
|
242
|
+
pred_x0 = torch.cos(s) * sample - torch.sin(s) * model_output
|
243
|
+
else:
|
244
|
+
raise ValueError(f"Unsupported parameterization: {parameterization}")
|
245
|
+
|
246
|
+
# 5. Sample z ~ N(0, I), For MultiStep Inference
|
247
|
+
# Noise is not used for one-step sampling.
|
248
|
+
if len(self.timesteps) > 1:
|
249
|
+
noise = (
|
250
|
+
randn_tensor(model_output.shape, device=model_output.device, generator=generator)
|
251
|
+
* self.config.sigma_data
|
252
|
+
)
|
253
|
+
prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
|
254
|
+
else:
|
255
|
+
prev_sample = pred_x0
|
256
|
+
|
257
|
+
self._step_index += 1
|
258
|
+
|
259
|
+
if not return_dict:
|
260
|
+
return (prev_sample, pred_x0)
|
261
|
+
|
262
|
+
return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
|
263
|
+
|
264
|
+
def __len__(self):
|
265
|
+
return self.config.num_train_timesteps
|
@@ -431,8 +431,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
431
431
|
|
432
432
|
if timesteps[0] >= self.config.num_train_timesteps:
|
433
433
|
raise ValueError(
|
434
|
-
f"`timesteps` must start before `self.config.train_timesteps`:"
|
435
|
-
f" {self.config.num_train_timesteps}."
|
434
|
+
f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
|
436
435
|
)
|
437
436
|
|
438
437
|
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
|
@@ -19,6 +19,7 @@ from typing import Optional, Union
|
|
19
19
|
|
20
20
|
import torch
|
21
21
|
from huggingface_hub.utils import validate_hf_hub_args
|
22
|
+
from typing_extensions import Self
|
22
23
|
|
23
24
|
from ..utils import BaseOutput, PushToHubMixin
|
24
25
|
|
@@ -99,7 +100,7 @@ class SchedulerMixin(PushToHubMixin):
|
|
99
100
|
subfolder: Optional[str] = None,
|
100
101
|
return_unused_kwargs=False,
|
101
102
|
**kwargs,
|
102
|
-
):
|
103
|
+
) -> Self:
|
103
104
|
r"""
|
104
105
|
Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.
|
105
106
|
|