diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
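The largest additions below are the new Optimum Quanto quantization backend (`diffusers/quantizers/quanto/`) plus related torchao and scheduler changes. As a quick orientation before the hunks, here is a minimal loading sketch using the `QuantoConfig` class that 0.33.0 exports; the checkpoint id is illustrative, not taken from this diff:

```python
import torch

from diffusers import FluxTransformer2DModel, QuantoConfig

# weights_dtype accepts "float8", "int8", "int4", or "int2"; see the dtype
# mapping in QuantoQuantizer.adjust_target_dtype in the hunk below.
quant_config = QuantoConfig(weights_dtype="int8")

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```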
diffusers/quantizers/quanto/quanto_quantizer.py
ADDED
@@ -0,0 +1,177 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from diffusers.utils.import_utils import is_optimum_quanto_version
+
+from ...utils import (
+    get_module_from_name,
+    is_accelerate_available,
+    is_accelerate_version,
+    is_optimum_quanto_available,
+    is_torch_available,
+    logging,
+)
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+    import torch
+
+if is_accelerate_available():
+    from accelerate.utils import CustomDtype, set_module_tensor_to_device
+
+if is_optimum_quanto_available():
+    from .utils import _replace_with_quanto_layers
+
+logger = logging.get_logger(__name__)
+
+
+class QuantoQuantizer(DiffusersQuantizer):
+    r"""
+    Diffusers Quantizer for Optimum Quanto
+    """
+
+    use_keep_in_fp32_modules = True
+    requires_calibration = False
+    required_packages = ["quanto", "accelerate"]
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_optimum_quanto_available():
+            raise ImportError(
+                "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
+            )
+        if not is_optimum_quanto_version(">=", "0.2.6"):
+            raise ImportError(
+                "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
+                "Please upgrade your installation with `pip install --upgrade optimum-quanto"
+            )
+
+        if not is_accelerate_available():
+            raise ImportError(
+                "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
+            )
+
+        device_map = kwargs.get("device_map", None)
+        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+            raise ValueError(
+                "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
+            )
+
+    def check_if_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ):
+        # Quanto imports diffusers internally. This is here to prevent circular imports
+        from optimum.quanto import QModuleMixin, QTensor
+        from optimum.quanto.tensor.packed import PackedTensor
+
+        module, tensor_name = get_module_from_name(model, param_name)
+        if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
+            return True
+        elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
+            return not module.frozen
+
+        return False
+
+    def create_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        target_device: "torch.device",
+        *args,
+        **kwargs,
+    ):
+        """
+        Create the quantized parameter by calling .freeze() after setting it to the module.
+        """
+
+        dtype = kwargs.get("dtype", torch.float32)
+        module, tensor_name = get_module_from_name(model, param_name)
+        if self.pre_quantized:
+            setattr(module, tensor_name, param_value)
+        else:
+            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
+            module.freeze()
+            module.weight.requires_grad = False
+
+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+        return max_memory
+
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        if is_accelerate_version(">=", "0.27.0"):
+            mapping = {
+                "int8": torch.int8,
+                "float8": CustomDtype.FP8,
+                "int4": CustomDtype.INT4,
+                "int2": CustomDtype.INT2,
+            }
+            target_dtype = mapping[self.quantization_config.weights_dtype]
+
+        return target_dtype
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
+        if torch_dtype is None:
+            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
+            torch_dtype = torch.float32
+        return torch_dtype
+
+    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+        # Quanto imports diffusers internally. This is here to prevent circular imports
+        from optimum.quanto import QModuleMixin
+
+        not_missing_keys = []
+        for name, module in model.named_modules():
+            if isinstance(module, QModuleMixin):
+                for missing in missing_keys:
+                    if (
+                        (name in missing or name in f"{prefix}.{missing}")
+                        and not missing.endswith(".weight")
+                        and not missing.endswith(".bias")
+                    ):
+                        not_missing_keys.append(missing)
+        return [k for k in missing_keys if k not in not_missing_keys]
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "ModelMixin",
+        device_map,
+        keep_in_fp32_modules: List[str] = [],
+        **kwargs,
+    ):
+        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+        if not isinstance(self.modules_to_not_convert, list):
+            self.modules_to_not_convert = [self.modules_to_not_convert]
+
+        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        model = _replace_with_quanto_layers(
+            model,
+            modules_to_not_convert=self.modules_to_not_convert,
+            quantization_config=self.quantization_config,
+            pre_quantized=self.pre_quantized,
+        )
+        model.config.quantization_config = self.quantization_config
+
+    def _process_model_after_weight_loading(self, model, **kwargs):
+        return model
+
+    @property
+    def is_trainable(self):
+        return True
+
+    @property
+    def is_serializable(self):
+        return True
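The un-frozen path in `create_quantized_param` above mirrors optimum-quanto's own two-step workflow: `quantize()` swaps eligible layers for `QModuleMixin` subclasses that still hold float weights (the state `check_if_quantized_param` detects via `module.frozen`), and `freeze()` then materializes the packed quantized weights. A standalone sketch with those optimum-quanto primitives, independent of diffusers:

```python
import torch
import torch.nn as nn
from optimum.quanto import freeze, qint8, quantize

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# After quantize(), the Linear layers are QLinear modules whose weights are
# still float tensors; module.frozen is False at this point.
quantize(model, weights=qint8)

# freeze() replaces the float weights with packed quantized tensors, the
# same call the quantizer makes per-parameter during loading.
freeze(model)

with torch.no_grad():
    print(model(torch.randn(1, 64)).shape)
```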
diffusers/quantizers/quanto/utils.py
ADDED
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+from ...utils import is_accelerate_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
+
+def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
+    # Quanto imports diffusers internally. These are placed here to avoid circular imports
+    from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8
+
+    def _get_weight_type(dtype: str):
+        return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]
+
+    def _replace_layers(model, quantization_config, modules_to_not_convert):
+        has_children = list(model.children())
+        if not has_children:
+            return model
+
+        for name, module in model.named_children():
+            _replace_layers(module, quantization_config, modules_to_not_convert)
+
+            if name in modules_to_not_convert:
+                continue
+
+            if isinstance(module, nn.Linear):
+                with init_empty_weights():
+                    qlinear = QLinear(
+                        in_features=module.in_features,
+                        out_features=module.out_features,
+                        bias=module.bias is not None,
+                        dtype=module.weight.dtype,
+                        weights=_get_weight_type(quantization_config.weights_dtype),
+                    )
+                model._modules[name] = qlinear
+                model._modules[name].source_cls = type(module)
+                model._modules[name].requires_grad_(False)
+
+        return model
+
+    model = _replace_layers(model, quantization_config, modules_to_not_convert)
+    has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())
+
+    if not has_been_replaced:
+        logger.warning(
+            f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
+            " Please check your model architecture, or submit an issue on Github if you think this is a bug."
+            " https://github.com/huggingface/diffusers/issues/new"
+        )
+
+    # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
+    # to match when trying to load weights with load_model_dict_into_meta
+    if pre_quantized:
+        freeze(model)
+
+    return model
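`_replace_layers` above follows the usual recursive module-surgery pattern: recurse into `named_children()` first, then swap matching leaves through `model._modules[name]`. A minimal standalone illustration of that pattern in plain PyTorch (the `Wrapped` class is hypothetical, standing in for `QLinear`):

```python
import torch.nn as nn


class Wrapped(nn.Module):
    """Hypothetical stand-in for QLinear; just forwards to the original layer."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear

    def forward(self, x):
        return self.inner(x)


def replace_linears(model: nn.Module, modules_to_not_convert: list):
    for name, module in model.named_children():
        # Recurse before swapping, exactly as _replace_layers does.
        replace_linears(module, modules_to_not_convert)
        if name in modules_to_not_convert:
            continue
        if isinstance(module, nn.Linear):
            model._modules[name] = Wrapped(module)
    return model


net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
replace_linears(net, modules_to_not_convert=[])
print(net)
```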
diffusers/quantizers/torchao/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/quantizers/torchao/torchao_quantizer.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
 
-from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
+from ...utils import (
+    get_module_from_name,
+    is_torch_available,
+    is_torch_version,
+    is_torchao_available,
+    is_torchao_version,
+    logging,
+)
 from ..base import DiffusersQuantizer
 
 
@@ -62,6 +69,43 @@ if is_torchao_available():
     from torchao.quantization import quantize_
 
 
+def _update_torch_safe_globals():
+    safe_globals = [
+        (torch.uint1, "torch.uint1"),
+        (torch.uint2, "torch.uint2"),
+        (torch.uint3, "torch.uint3"),
+        (torch.uint4, "torch.uint4"),
+        (torch.uint5, "torch.uint5"),
+        (torch.uint6, "torch.uint6"),
+        (torch.uint7, "torch.uint7"),
+    ]
+    try:
+        from torchao.dtypes import NF4Tensor
+        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
+        from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
+
+        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
+
+    except (ImportError, ModuleNotFoundError) as e:
+        logger.warning(
+            "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
+        )
+        logger.debug(e)
+
+    finally:
+        torch.serialization.add_safe_globals(safe_globals=safe_globals)
+
+
+if (
+    is_torch_available()
+    and is_torch_version(">=", "2.6.0")
+    and is_torchao_available()
+    and is_torchao_version(">=", "0.7.0")
+):
+    _update_torch_safe_globals()
+
+
 logger = logging.get_logger(__name__)
 
 
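Context for the hunk above: PyTorch 2.6 made `weights_only=True` the default for `torch.load`, and the weights-only unpickler rejects any non-allowlisted class, including the torchao tensor subclasses found in serialized torchao checkpoints. `_update_torch_safe_globals` pre-registers those types. A self-contained sketch of the same mechanism with a stand-in class:

```python
import torch


class ScaleRecord:
    """Stand-in for a custom serialized type such as torchao's NF4Tensor."""

    def __init__(self, value: float):
        self.value = value


torch.save({"scale": ScaleRecord(0.5), "w": torch.ones(2)}, "ckpt.pt")

# Without this allowlisting, torch.load(..., weights_only=True) on
# PyTorch >= 2.6 raises an UnpicklingError for ScaleRecord.
torch.serialization.add_safe_globals([ScaleRecord])
state = torch.load("ckpt.pt", weights_only=True)
```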
@@ -215,6 +259,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
         target_device: "torch.device",
         state_dict: Dict[str, Any],
         unexpected_keys: List[str],
+        **kwargs,
     ):
         r"""
         Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
diffusers/schedulers/__init__.py
CHANGED
@@ -68,6 +68,7 @@ else:
     _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
     _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
     _import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
+    _import_structure["scheduling_scm"] = ["SCMScheduler"]
    _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
     _import_structure["scheduling_tcd"] = ["TCDScheduler"]
     _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
@@ -168,13 +169,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .scheduling_pndm import PNDMScheduler
     from .scheduling_repaint import RePaintScheduler
     from .scheduling_sasolver import SASolverScheduler
+    from .scheduling_scm import SCMScheduler
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_tcd import TCDScheduler
     from .scheduling_unclip import UnCLIPScheduler
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
     from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
-
     try:
         if not is_flax_available():
             raise OptionalDependencyNotAvailable()
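`SCMScheduler` registered above ships in the new `scheduling_scm.py` (+265 lines in the file list) and backs the SANA-Sprint pipeline added in this release. After this change it resolves through the package's lazy-import machinery like any other scheduler, e.g.:

```python
from diffusers import SCMScheduler  # resolves via the lazy _import_structure above

scheduler = SCMScheduler()  # library defaults from scheduling_scm.py
```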
diffusers/schedulers/scheduling_consistency_models.py
CHANGED
@@ -203,8 +203,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 
         if timesteps[0] >= self.config.num_train_timesteps:
             raise ValueError(
-                f"`timesteps` must start before `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps}."
+                f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
             )
 
         timesteps = np.array(timesteps, dtype=np.int64)
diffusers/schedulers/scheduling_ddim_inverse.py
CHANGED
@@ -266,7 +266,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 
         self.num_inference_steps = num_inference_steps
 
-        # "leading" and "trailing" corresponds to annotation of Table
+        # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
diffusers/schedulers/scheduling_ddpm.py
CHANGED
@@ -142,7 +142,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             The final `beta` value.
         beta_schedule (`str`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+            `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`.
         trained_betas (`np.ndarray`, *optional*):
             An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
         variance_type (`str`, defaults to `"fixed_small"`):
@@ -279,8 +279,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         if timesteps[0] >= self.config.num_train_timesteps:
             raise ValueError(
-                f"`timesteps` must start before `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps}."
+                f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
             )
 
         timesteps = np.array(timesteps, dtype=np.int64)
diffusers/schedulers/scheduling_ddpm_parallel.py
CHANGED
@@ -289,8 +289,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 
         if timesteps[0] >= self.config.num_train_timesteps:
             raise ValueError(
-                f"`timesteps` must start before `self.config.train_timesteps`:"
-                f" {self.config.num_train_timesteps}."
+                f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
             )
 
         timesteps = np.array(timesteps, dtype=np.int64)
diffusers/schedulers/scheduling_dpmsolver_multistep.py
CHANGED
@@ -136,8 +136,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             sampling, and `solver_order=3` for unconditional sampling.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://imagen.research.google/video/paper.pdf) paper).
+            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
@@ -174,6 +174,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
             `lambda(t)`.
+        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+        flow_shift (`float`, *optional*, defaults to 1.0):
+            The shift value for the timestep schedule for flow matching.
         final_sigmas_type (`str`, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
@@ -395,12 +399,16 @@
         if self.config.use_karras_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.beta_schedule != "squaredcos_cap_v2":
+                timesteps = timesteps.round()
         elif self.config.use_lu_lambdas:
             lambdas = np.flip(log_sigmas.copy())
             lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
             sigmas = np.exp(lambdas)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if self.config.beta_schedule != "squaredcos_cap_v2":
+                timesteps = timesteps.round()
         elif self.config.use_exponential_sigmas:
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
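The `use_flow_sigmas` option documented above lets DPM-Solver++ act as a sampler for flow-matching models. A hedged construction sketch (the shift value is illustrative):

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    use_flow_sigmas=True,  # use the flow-matching sigma schedule
    flow_shift=3.0,  # timestep-schedule shift; 1.0 means unshifted
    prediction_type="flow_prediction",
)
scheduler.set_timesteps(num_inference_steps=28)
print(scheduler.timesteps[:4])
```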
diffusers/schedulers/scheduling_edm_euler.py
CHANGED
@@ -14,7 +14,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
             Video](https://imagen.research.google/video/paper.pdf) paper).
         rho (`float`, *optional*, defaults to 7.0):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
     """
 
     _compatibles = []
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps = None
 
-        ramp = torch.linspace(0, 1, num_train_timesteps)
+        sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
         if sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
 
         self.is_scale_input_called = False
 
@@ -197,7 +210,12 @@
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -206,19 +224,36 @@
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+                Custom sigmas to use for the denoising process. If not defined, the default behavior when
+                `num_inference_steps` is passed will be used.
         """
         self.num_inference_steps = num_inference_steps
 
-        ramp = torch.linspace(0, 1, self.num_inference_steps)
+        if sigmas is None:
+            sigmas = torch.linspace(0, 1, self.num_inference_steps)
+        elif isinstance(sigmas, float):
+            sigmas = torch.tensor(sigmas, dtype=torch.float32)
+        else:
+            sigmas = sigmas
         if self.config.sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif self.config.sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
 
         sigmas = sigmas.to(dtype=torch.float32, device=device)
         self.timesteps = self.precondition_noise(sigmas)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication